YAML Metadata Warning: empty or missing YAML metadata in the repo card
Check out the documentation for more information.
Requirements
torch
torchvision
pillow
kornia
transformers
Usage
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Usage example: remove the background of the current video frame using the
# previous frame as additional input, then save the result as an RGBA PNG.

# Load model.
# NOTE: trust_remote_code=True runs Python code shipped in the model
# repository; only use it with repositories you trust.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForImageSegmentation.from_pretrained('briaai/VRMBG-2.0', trust_remote_code=True).eval().half().to(device)
model = torch.compile(model, dynamic=False, backend="inductor")

# Data settings
image_size = (1024, 1024)
transform_image = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Load inputs
image_prev_frame_path = ""  # fill path
image_current_frame_path = ""  # fill path
# Force 3-channel RGB: Normalize above is given 3-channel statistics, so an
# RGBA, palette or grayscale file would otherwise produce a tensor with the
# wrong channel count. convert() returns a fully loaded copy, so the file
# handles can be closed right away.
with Image.open(image_prev_frame_path) as prev_frame_file:
    image_prev_frame = prev_frame_file.convert("RGB")
with Image.open(image_current_frame_path) as current_frame_file:
    image_current_frame = current_frame_file.convert("RGB")

# Stack both frames along the channel axis (3 + 3 channels), then add a batch
# dimension and match the model's device and half precision.
input_frames = torch.cat([transform_image(image_prev_frame), transform_image(image_current_frame)], dim=0)
input_frames = input_frames.unsqueeze(0).to(device).half()

# Prediction
with torch.no_grad():
    preds = model(input_frames).cpu()

# Take the first item of the batch and drop singleton dimensions.
# NOTE(review): assumes the model returns a single-channel mask with values
# in [0, 1] per batch item -- confirm against the model's remote code.
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
# Scale the mask back to the original frame resolution.
mask = pred_pil.resize(image_current_frame.size)

# Save result: use the predicted mask as the alpha channel.
image_current_frame.putalpha(mask)
image_current_frame.save("image_current_frame_no_background.png")
- Downloads last month
- -
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support