Spaces:
Runtime error
Runtime error
| import torch | |
| import gradio as gr | |
| from transformers import AutoProcessor, AutoModel | |
| from pathlib import Path | |
| import numpy as np | |
| from decord import VideoReader | |
| import imageio | |
# Upper bound on the spacing (in source frames) between two sampled frames;
# get_frame_sampling_rate() lowers it for videos that are too short.
FRAME_SAMPLING_RATE = 4
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

# The processor/model pair used by predict(); loaded once at import time and
# replaceable at runtime through select_model().
processor = AutoProcessor.from_pretrained(DEFAULT_MODEL)
model = AutoModel.from_pretrained(DEFAULT_MODEL)

# Default comma-separated candidate labels. No spaces around the commas, so
# every label reaches the model (and the prediction dict) without padding.
ROOMS = (
    "bathroom,sauna,living room,bedroom,kitchen,toilet,hallway,dressing,attic,basement"
)

# Example rows for the UI: [video path, labels text].
examples = [
    [
        "movies/bathroom.mp4",
        ROOMS,
    ],
]
def sample_frames_from_video_file(
    file_path: str, num_frames: int = 16, frame_sampling_rate=1
):
    """Decode ``num_frames`` frames spread over the start of a video.

    Frame indices are placed with ``np.linspace`` over the first
    ``num_frames * frame_sampling_rate`` frames. They are then clamped to the
    last frame of the video, so a clip shorter than that window repeats its
    final frame instead of requesting out-of-range indices.

    Args:
        file_path: Path to a video file readable by decord's ``VideoReader``.
        num_frames: Number of frames to return.
        frame_sampling_rate: Approximate spacing between sampled frames.
            Values below 1 are treated as 1.

    Returns:
        The array produced by ``get_batch(indices).asnumpy()``. Presumably
        (num_frames, height, width, channels); confirm against decord.

    Raises:
        ValueError: if the video reports zero frames.
    """
    videoreader = VideoReader(file_path)
    videoreader.seek(0)
    num_total_frames = len(videoreader)
    if num_total_frames == 0:
        raise ValueError(f"Video has no decodable frames: {file_path}")

    # A rate of 0 would make end_idx == -1 and produce a degenerate index range.
    rate = max(1, frame_sampling_rate)
    end_idx = num_frames * rate - 1
    indices = np.linspace(0, end_idx, num=num_frames, dtype=np.int64)
    # Short videos: never ask for a frame past the end.
    indices = np.clip(indices, 0, num_total_frames - 1)
    return videoreader.get_batch(indices).asnumpy()
def get_num_total_frames(file_path: str):
    """Return the frame count that decord reports for the video at ``file_path``."""
    reader = VideoReader(file_path)
    # Rewind to the first frame before querying, as the sampling helper does.
    reader.seek(0)
    frame_count = len(reader)
    return frame_count
| # def convert_frames_to_gif(frames, save_path: str = "frames.gif"): | |
| # converted_frames = frames.astype(np.uint8) | |
| # Path(save_path).parent.mkdir(parents=True, exist_ok=True) | |
| # imageio.mimsave(save_path, converted_frames, fps=8) | |
| # return save_path | |
| # def create_gif_from_video_file( | |
| # file_path: str, | |
| # num_frames: int = 16, | |
| # frame_sampling_rate: int = 1, | |
| # save_path: str = "frames.gif", | |
| # ): | |
| # frames = sample_frames_from_video_file(file_path, num_frames, frame_sampling_rate) | |
| # return convert_frames_to_gif(frames, save_path) | |
def select_model(model_name):
    """Make ``model_name`` the processor/model pair used by predict().

    Both objects are loaded before either module-level global is replaced.
    If one download fails, the previously loaded pair stays intact instead
    of leaving a new processor paired with the old model.

    Args:
        model_name: A HuggingFace model id or local path understood by
            ``AutoProcessor``/``AutoModel.from_pretrained``.
    """
    global processor, model
    new_processor = AutoProcessor.from_pretrained(model_name)
    new_model = AutoModel.from_pretrained(model_name)
    processor, model = new_processor, new_model
def get_frame_sampling_rate(
    video_path,
    num_model_input_frames,
    num_total_frames=None,
    max_sampling_rate=None,
):
    """Pick the frame spacing so that the sampled window fits inside the video.

    The default spacing (``FRAME_SAMPLING_RATE``) is used when the video has
    at least ``max_sampling_rate * num_model_input_frames`` frames. Otherwise
    the spacing shrinks to ``num_total_frames // num_model_input_frames``,
    but never below 1.

    Args:
        video_path: Video whose length is measured when ``num_total_frames``
            is not given.
        num_model_input_frames: Number of frames the model consumes; must be >= 1.
        num_total_frames: Optional known frame count (skips opening the video).
        max_sampling_rate: Optional upper bound on the spacing; defaults to
            ``FRAME_SAMPLING_RATE``.

    Returns:
        An int sampling rate >= 1.

    Raises:
        ValueError: if ``num_model_input_frames`` is less than 1.
    """
    if num_model_input_frames < 1:
        raise ValueError(
            f"num_model_input_frames must be >= 1, got {num_model_input_frames}"
        )
    if max_sampling_rate is None:
        max_sampling_rate = FRAME_SAMPLING_RATE
    if num_total_frames is None:
        num_total_frames = get_num_total_frames(video_path)

    if num_total_frames < max_sampling_rate * num_model_input_frames:
        # Video too short for the default spacing. A rate of 0 (fewer frames
        # than the model needs) would collapse the sampled index range, so
        # clamp it to 1.
        return max(1, num_total_frames // num_model_input_frames)
    return max_sampling_rate
def predict(video_path, labels_text):
    """Zero-shot classify a video against comma-separated text labels.

    Labels are split on commas, stripped of surrounding whitespace, and
    de-duplicated (first occurrence wins). Empty entries are dropped. Frames
    are sampled with get_frame_sampling_rate() and
    sample_frames_from_video_file(), then scored with the module-level
    ``processor``/``model``.

    Args:
        video_path: Path to the uploaded video (None when nothing was uploaded).
        labels_text: Comma-separated candidate labels.

    Returns:
        A dict mapping each label to its softmax probability (a Python float).

    Raises:
        ValueError: if no video was given or no non-empty label remains.
    """
    if not video_path:
        raise ValueError("Please upload a video file.")

    stripped = (label.strip() for label in (labels_text or "").split(","))
    # dict.fromkeys keeps order while dropping duplicates, so a repeated label
    # does not split probability mass and then collapse in the result dict.
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        raise ValueError("Please provide at least one comma-separated label.")

    num_model_input_frames = model.config.vision_config.num_frames
    frame_sampling_rate = get_frame_sampling_rate(video_path, num_model_input_frames)
    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )

    inputs = processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )
    with torch.no_grad():  # inference only
        outputs = model(**inputs)

    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    return {label: float(prob) for label, prob in zip(labels, probs)}
app = gr.Blocks()
with app:
    gr.Markdown(
        "# **<p align='center'>Classification of Rooms</p>**"
    )
    gr.Markdown(
        "### **<p align='center'>Upload a video of a room and provide a list of type of rooms the model should select from.</p>**"
    )
    with gr.Row():
        with gr.Column():
            video_file = gr.Video(label="Video File:", show_label=True)
            # Prefilled with the default room list so the app is usable as-is.
            local_video_labels_text = gr.Textbox(
                label="Labels Text:", show_label=True, value=ROOMS
            )
            submit_button = gr.Button(value="Predict")
        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    # Show the examples section only for example videos that ship with the
    # app. Otherwise the "Examples" heading would render with nothing under
    # it. Clicking an example only fills the inputs (no cached outputs).
    available_examples = [row for row in examples if Path(row[0]).exists()]
    if available_examples:
        gr.Markdown("**Examples:**")
        gr.Examples(
            available_examples,
            inputs=[video_file, local_video_labels_text],
        )

    submit_button.click(
        predict,
        inputs=[video_file, local_video_labels_text],
        outputs=predictions,
    )

# Credits (not rendered): created by Vincent Claes, Drift
# (https://www.meet-drift.ai/); inspired by fcakyon's
# zero-shot-video-classification Space
# (https://huggingface.co/spaces/fcakyon/zero-shot-video-classification).
app.launch()