Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import tempfile | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from PIL import Image, ImageDraw | |
| from transformers import ( | |
| Sam3Model, Sam3Processor, | |
| Sam3TrackerModel, Sam3TrackerProcessor | |
| ) | |
| # ============ THEME SETUP ============ | |
| # Theme disabled to avoid API schema issues on HF Spaces | |
| # ============ GLOBAL SETUP ============ | |
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using compute device: {device}")

# Hugging Face repository id shared by the image model, the tracker model and
# both processors.
SAM3_REPO_ID = "DiffusionWave/sam3"


def _load_sam3_models(local_files_only):
    """Load the SAM3 image/tracker models and their processors.

    Both models are loaded with ``device_map="cpu"`` and ``torch.float32``.

    Args:
        local_files_only: When True, ``local_files_only=True`` is forwarded to
            every ``from_pretrained`` call; when False the argument is omitted
            and the library default applies.

    Returns:
        Tuple ``(img_model, img_processor, trk_model, trk_processor)``.

    Raises:
        Exception: Anything raised by the ``from_pretrained`` calls is
            propagated so the caller can decide on a fallback.
    """
    hub_kwargs = {"local_files_only": True} if local_files_only else {}
    model_kwargs = {"device_map": "cpu", "torch_dtype": torch.float32}
    img_model = Sam3Model.from_pretrained(SAM3_REPO_ID, **hub_kwargs, **model_kwargs)
    img_processor = Sam3Processor.from_pretrained(SAM3_REPO_ID, **hub_kwargs)
    trk_model = Sam3TrackerModel.from_pretrained(SAM3_REPO_ID, **hub_kwargs, **model_kwargs)
    trk_processor = Sam3TrackerProcessor.from_pretrained(SAM3_REPO_ID, **hub_kwargs)
    return img_model, img_processor, trk_model, trk_processor


# Load models
print("⏳ Loading SAM3 Models permanently into memory...")
try:
    # First attempt: load from the local cache only (offline mode).
    print(" ... Loading from local cache (offline mode)")
    IMG_MODEL, IMG_PROCESSOR, TRK_MODEL, TRK_PROCESSOR = _load_sam3_models(local_files_only=True)
    print("✅ All Models loaded successfully from local cache!")
except Exception as e:
    # Top-level boundary: report the failure and fall back to online loading.
    print(f"❌ Cache loading failed: {e}")
    print(" Trying online loading...")
    try:
        IMG_MODEL, IMG_PROCESSOR, TRK_MODEL, TRK_PROCESSOR = _load_sam3_models(local_files_only=False)
        print("✅ All Models loaded successfully (CPU mode)!")
    except Exception as e2:
        print(f"❌ Online loading also failed: {e2}")
        # Keep the module importable; consumers must check for None.
        IMG_MODEL = IMG_PROCESSOR = TRK_MODEL = TRK_PROCESSOR = None
| # ============ LAYER MANAGEMENT ============ | |
class LayerManager:
    """Container for layer-based segmentation state.

    Every layer lives in ``self.layers`` under a generated id
    ("layer_0", "layer_1", ...) and is a dict with the keys
    'name', 'color', 'points', 'point_labels', 'masks' and 'area'.
    """

    def __init__(self):
        # layer_id -> {'name': str, 'color': tuple, 'points': list,
        #              'point_labels': list, 'masks': list, 'area': float}
        self.layers = {}
        self.current_layer_id = None
        self.layer_counter = 0

    def create_layer(self, name, color=None):
        """Create a new empty layer and return its id.

        When ``color`` is None a random RGB tuple is generated, each channel
        drawn from 50..200.
        """
        if color is None:
            # Random color generation
            from random import randint
            color = (randint(50, 200), randint(50, 200), randint(50, 200))
        new_id = f"layer_{self.layer_counter}"
        self.layer_counter += 1
        self.layers[new_id] = dict(
            name=name,
            color=color,
            points=[],
            point_labels=[],  # 1: positive, 0: negative
            masks=[],
            area=0.0,
        )
        return new_id

    def add_point_to_layer(self, layer_id, point, label=1):
        """Append ``point`` and its ``label`` to a layer; unknown ids are ignored."""
        entry = self.layers.get(layer_id)
        if entry is None:
            return
        entry['points'].append(point)
        entry['point_labels'].append(label)
        print(f"[add_point_to_layer] Added point to '{entry['name']}': {point}, label={label}")
        print(f"[add_point_to_layer] Total points in '{entry['name']}': {len(entry['points'])}")

    def add_mask_to_layer(self, layer_id, mask):
        """Store ``mask`` as the layer's only mask and recompute its area.

        The area is the number of mask elements greater than zero. Unknown
        layer ids are ignored.
        """
        entry = self.layers.get(layer_id)
        if entry is None:
            return
        # Replace any previous mask (re-segmentation of the same layer).
        entry['masks'] = [mask]
        # Area computation works on a numpy array, so convert tensors first.
        pixels = mask.cpu().numpy() if isinstance(mask, torch.Tensor) else mask
        covered = np.sum(pixels > 0)
        entry['area'] = covered
        # Debug output: mask information
        print(f"[add_mask_to_layer] Layer: {entry['name']}, Mask shape: {pixels.shape}, Area: {covered}")

    def get_current_layer(self):
        """Return the currently selected layer dict, or None if there is none."""
        selected = self.current_layer_id
        if not selected or selected not in self.layers:
            return None
        return self.layers[selected]

    def set_current_layer(self, layer_id):
        """Mark ``layer_id`` as the current layer."""
        self.current_layer_id = layer_id

    def clear_current_layer(self):
        """Reset points, labels, masks and area of the current layer, if any."""
        selected = self.current_layer_id
        if not selected or selected not in self.layers:
            return
        self.layers[selected].update(points=[], point_labels=[], masks=[], area=0.0)
def calculate_total_area_ratio(layer_manager, total_pixels):
    """Compute each layer's area as a percentage of the whole image.

    Args:
        layer_manager: Object whose ``layers`` mapping holds layer dicts with
            'name', 'area' and 'masks' entries.
        total_pixels: Pixel count of the full image.

    Returns:
        A list with one dict per layer containing 'layer_name',
        'area_pixels' (int) and 'ratio_percent' (rounded to two decimals).
        The ratio is 0 when ``total_pixels`` or the layer area is not positive.
    """
    results = []
    for info in layer_manager.layers.values():
        area = info['area']
        mask_count = len(info['masks'])
        if total_pixels > 0 and area > 0:
            ratio = (area / total_pixels) * 100
        else:
            ratio = 0
        has_mask = mask_count > 0
        # Debug output: layer information
        print(f"[calculate_total_area_ratio] Layer: {info['name']}, Area: {area}, Ratio: {ratio}%, Masks: {mask_count}, Has mask: {has_mask}")
        results.append(
            dict(
                layer_name=info['name'],
                area_pixels=int(area),
                ratio_percent=round(ratio, 2),
            )
        )
    return results
def create_area_chart_data(ratios):
    """Convert per-layer area ratios into a display table.

    Args:
        ratios: List of dicts with 'layer_name', 'area_pixels' and
            'ratio_percent' entries.

    Returns:
        A pandas DataFrame with the columns "Layer", "Area (pixels)" and
        "Ratio(%)"; pixel counts are formatted with thousands separators and
        ratios get a trailing percent sign. An empty input yields an empty
        frame with the same columns.
    """
    column_names = ["Layer", "Area (pixels)", "Ratio(%)"]
    if not ratios:
        return pd.DataFrame(columns=column_names)
    rows = [
        {
            "Layer": entry['layer_name'],
            "Area (pixels)": f"{entry['area_pixels']:,}",
            "Ratio(%)": f"{entry['ratio_percent']}%",
        }
        for entry in ratios
    ]
    return pd.DataFrame(rows)
| # ============ UTILITY FUNCTIONS ============ | |
def _mask_to_binary_2d(mask):
    """Return ``mask`` as a 2-D uint8 array containing only 0 and 255."""
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)
    # Drop leading axes by keeping their first entry, e.g. (1, H, W) -> (H, W).
    while mask.ndim > 2:
        mask = mask[0]
    # "> 0" is the same foreground test used for the layer area computation.
    return np.where(mask > 0, 255, 0).astype(np.uint8)


def compose_all_layers(base_image, layer_manager, opacity=0.5, border_width=2):
    """Blend every layer's masks over the base image.

    Args:
        base_image: PIL image or numpy array to draw on.
        layer_manager: Object whose ``layers`` mapping holds layer dicts with
            'color' (RGB tuple) and 'masks' (list of arrays/tensors).
        opacity: Mask fill opacity; values outside 0.0..1.0 are clamped.
        border_width: Width in pixels of the white outline drawn around each
            mask; 0 disables the outline. Non-integer values are truncated.

    Returns:
        The composited picture as an RGB PIL image.
    """
    if isinstance(base_image, np.ndarray):
        base_image = Image.fromarray(base_image)
    base_image = base_image.convert("RGBA")
    if not layer_manager.layers:
        return base_image.convert("RGB")

    # UI sliders may hand over floats; normalise once before the loops.
    opacity = min(max(float(opacity), 0.0), 1.0)
    border_px = max(int(border_width), 0)
    # The previous formula multiplied an already 0-255 pixel value by
    # opacity * 255, so the alpha was out of range for every opacity setting.
    fill_alpha = int(round(opacity * 255))

    overlay = Image.new("RGBA", base_image.size, (0, 0, 0, 0))
    for layer in layer_manager.layers.values():
        fill_color = tuple(layer['color'])[:3] + (0,)
        for mask in layer['masks']:
            mask_u8 = _mask_to_binary_2d(mask)
            mask_img = Image.fromarray(mask_u8)

            # Outline: dilate the mask and keep only the ring outside of it.
            if border_px > 0:
                try:
                    kernel_size = border_px * 2 + 1
                    kernel = np.ones((kernel_size, kernel_size), np.uint8)
                    ring = cv2.dilate(mask_u8, kernel) - mask_u8
                    border_layer = Image.new("RGBA", base_image.size, (255, 255, 255, 255))  # white outline
                    border_layer.putalpha(Image.fromarray(ring))
                    # The outline goes in first so the color fill sits on top.
                    overlay = Image.alpha_composite(overlay, border_layer)
                except Exception as e:
                    # Best effort: a failed outline must not abort the render.
                    print(f"Border creation error: {e}")

            # Color fill limited to the mask area at the requested opacity.
            color_layer = Image.new("RGBA", base_image.size, fill_color)
            color_layer.putalpha(mask_img.point(lambda v: fill_alpha if v > 0 else 0))
            overlay = Image.alpha_composite(overlay, color_layer)

    # Final composition onto the base image.
    final_result = Image.alpha_composite(base_image, overlay)
    return final_result.convert("RGB")
def draw_points_on_image(image, layer_manager):
    """Draw the click points of every layer onto a copy of ``image``.

    Positive points (label 1) are red circles with a small white center;
    all other points are blue circles with a white X. Points of the current
    layer use radius 15, points of other layers radius 10.

    Returns:
        A new PIL image; the input image is not modified.
    """
    source = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    canvas = source.copy()
    pen = ImageDraw.Draw(canvas)
    current_id = layer_manager.current_layer_id
    for layer_id, layer in layer_manager.layers.items():
        radius = 15 if layer_id == current_id else 10
        for idx, (x, y) in enumerate(layer['points']):
            label = layer['point_labels'][idx]
            outer_box = (x - radius, y - radius, x + radius, y + radius)
            if label != 1:
                # Negative (0): large blue circle with an X on top
                pen.ellipse(outer_box, fill="blue", outline="white", width=3)
                arm = 8
                pen.line([(x - arm, y - arm), (x + arm, y + arm)], fill="white", width=3)
                pen.line([(x - arm, y + arm), (x + arm, y - arm)], fill="white", width=3)
                continue
            # Positive: large red circle with a small white dot in the center
            pen.ellipse(outer_box, fill="red", outline="white", width=3)
            pen.ellipse((x - 3, y - 3, x + 3, y + 3), fill="white")
    return canvas
| # ============ UI FUNCTIONS ============ | |
def create_new_layer(name, current_manager):
    """Create a layer named ``name`` and make it the current layer.

    Args:
        name: Layer name typed by the user; surrounding whitespace is
            ignored and None is treated as empty.
        current_manager: The session's LayerManager.

    Returns:
        Tuple ``(manager, status_html, dropdown_update, message)``. On an
        empty or duplicate name nothing is created and the dropdown keeps
        listing the existing layers.
    """
    clean_name = (name or "").strip()
    choices = [(layer['name'], lid) for lid, layer in current_manager.layers.items()]

    if not clean_name:
        # Keep the existing choices; an empty list would wipe the selector
        # even though layers still exist.
        return (
            current_manager,
            create_layer_status_html(current_manager),
            gr.Dropdown(choices=choices),
            "Please enter a layer name",
        )

    # Duplicate name check
    if any(layer['name'] == clean_name for layer in current_manager.layers.values()):
        return (
            current_manager,
            create_layer_status_html(current_manager),
            gr.Dropdown(choices=choices),
            f"Layer name '{clean_name}' already exists",
        )

    layer_id = current_manager.create_layer(clean_name)
    current_manager.set_current_layer(layer_id)
    # Update the dropdown choices with the new layer and select it.
    choices.append((clean_name, layer_id))
    return (
        current_manager,
        create_layer_status_html(current_manager),
        gr.Dropdown(choices=choices, value=layer_id),
        f"Layer '{clean_name}' created",
    )
def create_layer_status_html(current_manager):
    """Build the HTML snippet that visualises all layers (display only).

    Each layer becomes a colored card showing a status tag ("[OK]" when the
    layer has a mask, "[ ]" otherwise), the layer name, the number of
    positive (label 1) and negative (label 0) points and the mask count.
    The current layer is highlighted.

    Args:
        current_manager: Object with a ``layers`` mapping and a
            ``current_layer_id`` attribute.

    Returns:
        An HTML string. Layer names are HTML-escaped because they come
        straight from user input.
    """
    # Local import so this function does not depend on new file-level imports.
    from html import escape

    if not current_manager.layers:
        return "<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>"

    cards = []
    for layer_id, layer in current_manager.layers.items():
        # Color extraction
        r, g, b = layer['color']
        color_hex = f"#{r:02x}{g:02x}{b:02x}"

        # Style depends on whether this is the active layer.
        if current_manager.current_layer_id == layer_id:
            style = (
                f"background: linear-gradient(135deg, {color_hex}, {color_hex}dd); "
                "color: white; "
                "border: 3px solid #4682B4; "
                "box-shadow: 0 4px 12px rgba(70, 130, 180, 0.4);"
            )
        else:
            style = (
                f"background: linear-gradient(135deg, {color_hex}aa, {color_hex}77); "
                "color: white; "
                f"border: 2px solid {color_hex}; "
                "opacity: 0.7;"
            )

        # Point counts, split into positive and negative.
        labels = layer['point_labels']
        positive_points = sum(1 for label in labels if label == 1)
        negative_points = sum(1 for label in labels if label == 0)
        masks_count = len(layer['masks'])
        status_icon = "[OK]" if masks_count > 0 else "[ ]"
        # Escape the user-provided name so it cannot inject markup.
        safe_name = escape(str(layer['name']))

        cards.append("\n".join([
            f'<div style="{style}',
            'padding: 12px 20px;',
            'border-radius: 8px;',
            'font-weight: 600;',
            'font-size: 14px;',
            'min-width: 150px;">',
            f"{status_icon} {safe_name}<br>",
            "<small style='font-size: 11px; opacity: 0.9;'>",
            f"<span style='color: #ffcccc;'>+{positive_points}</span>",
            f"<span style='color: #ccccff;'>-{negative_points}</span>",
            f"{masks_count}mask",
            "</small>",
            "</div>",
        ]))

    opening = "<div style='display: flex; flex-wrap: wrap; gap: 8px; padding: 10px;'>"
    return opening + "\n" + "\n".join(cards) + "\n" + "</div>"
def click_on_image(current_manager, image, point_mode, evt: gr.SelectData):
    """Handle an image click by adding a point to the current layer.

    The point is labelled 1 when ``point_mode`` equals "positive" and 0
    otherwise. The clicked coordinates are taken from ``evt.index``.

    Returns:
        Tuple ``(image_with_points, manager, status_html, message)``. When no
        image or no current layer is available the input image is returned
        unchanged together with a hint message.
    """
    active_id = current_manager.current_layer_id
    if image is None or active_id is None:
        return image, current_manager, create_layer_status_html(current_manager), "Please select image and layer"

    x, y = evt.index
    # Label follows the point mode (positive=1, negative=0).
    is_positive = point_mode == "positive"
    label = 1 if is_positive else 0
    layer_name = current_manager.layers[active_id]['name']

    print(f"\n[click_on_image] ================")
    print(f"[click_on_image] Layer: {layer_name}")
    print(f"[click_on_image] Point mode: {point_mode}, Label: {label}, Position: ({x}, {y})")
    current_manager.add_point_to_layer(active_id, [x, y], label)

    # Render the points on top of the original image.
    annotated = draw_points_on_image(image, current_manager)
    mode_text = "Include" if is_positive else "Exclude"
    message = f"{mode_text} point added to '{layer_name}' at ({x}, {y})"
    return annotated, current_manager, create_layer_status_html(current_manager), message
def segment_all_layers(current_manager, image, opacity, border_width):
    """Run SAM3 tracker segmentation for every layer that has points.

    Layers without points are skipped. For each remaining layer the points
    and labels are sent to the tracker as a single prompt and the first
    returned mask replaces the layer's mask.

    Args:
        current_manager: The session's LayerManager.
        image: The uploaded PIL image, or None.
        opacity: Mask opacity forwarded to ``compose_all_layers``.
        border_width: Outline width forwarded to ``compose_all_layers``.

    Returns:
        Tuple ``(result_image, manager, status_html, message, area_table)``.
        ``result_image`` is None and the table is empty when there is nothing
        to do or an error occurred; the message then explains why.
    """
    if image is None:
        return None, current_manager, create_layer_status_html(current_manager), "Please upload an image", pd.DataFrame()
    if not current_manager.layers:
        return None, current_manager, create_layer_status_html(current_manager), "Please create layers first", pd.DataFrame()
    if TRK_MODEL is None or TRK_PROCESSOR is None:
        # Model loading sets these globals to None when every attempt failed.
        return None, current_manager, create_layer_status_html(current_manager), "Error: SAM3 tracker model is not loaded", pd.DataFrame()

    try:
        print(f"\n[segment_all_layers] Starting segmentation for all layers...")
        # Send inputs to the device the model actually lives on. The global
        # `device` may be "cuda" while the models are loaded with
        # device_map="cpu", which would cause a device mismatch.
        model_device = next(TRK_MODEL.parameters()).device
        segmented_count = 0
        skipped_count = 0

        # Iterate over all layers and segment each one.
        for layer_id, layer in current_manager.layers.items():
            layer_name = layer['name']
            # Skip layers without points.
            if not layer['points']:
                print(f"[segment_all_layers] Skipping '{layer_name}' - no points")
                skipped_count += 1
                continue

            print(f"\n[segment_all_layers] Processing layer: {layer_name}")
            print(f"[segment_all_layers] Points: {len(layer['points'])}, Labels: {layer['point_labels']}")

            # Segmentation with the SAM3 tracker: the layer's points and
            # labels are wrapped in two extra list levels for the processor.
            input_points = [[layer['points']]]
            input_labels = [[layer['point_labels']]]
            inputs = TRK_PROCESSOR(
                images=image,
                input_points=input_points,
                input_labels=input_labels,
                return_tensors="pt",
            ).to(model_device)
            with torch.no_grad():
                outputs = TRK_MODEL(**inputs, multimask_output=False)
            masks = TRK_PROCESSOR.post_process_masks(
                outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True
            )[0]

            # Store the mask on the layer.
            current_manager.add_mask_to_layer(layer_id, masks[0])
            segmented_count += 1
            print(f"[segment_all_layers] Completed '{layer_name}'")

        # Build the result image (including points).
        result_image = compose_all_layers(image, current_manager, opacity, border_width)
        result_image = draw_points_on_image(result_image, current_manager)

        # Area analysis
        total_pixels = image.size[0] * image.size[1]
        ratios = calculate_total_area_ratio(current_manager, total_pixels)
        chart_data = create_area_chart_data(ratios)

        status_msg = f"Segmentation completed! Processed: {segmented_count} layers, Skipped: {skipped_count} layers"
        print(f"\n[segment_all_layers] {status_msg}")
        return result_image, current_manager, create_layer_status_html(current_manager), status_msg, chart_data
    except Exception as e:
        # UI boundary: log the full traceback and surface the message to the user.
        import traceback
        print(f"[segment_all_layers] Error: {str(e)}")
        traceback.print_exc()
        return None, current_manager, create_layer_status_html(current_manager), f"Error: {str(e)}", pd.DataFrame()
def clear_current_layer(current_manager, image, opacity, border_width):
    """Reset the current layer and rebuild the preview and area table.

    Returns:
        Tuple ``(result_image, manager, status_html, message, area_table)``.
        Without a current layer nothing is cleared and the image slot is
        None with an empty table.
    """
    if not current_manager.current_layer_id:
        return None, current_manager, create_layer_status_html(current_manager), "Please select a layer", pd.DataFrame()

    current_manager.clear_current_layer()

    if image:
        preview = compose_all_layers(image, current_manager, opacity, border_width)
        preview = draw_points_on_image(preview, current_manager)
        total_pixels = image.size[0] * image.size[1]
    else:
        preview = None
        total_pixels = 0

    chart_data = create_area_chart_data(
        calculate_total_area_ratio(current_manager, total_pixels)
    )
    return preview, current_manager, create_layer_status_html(current_manager), "Layer cleared", chart_data
def refresh_visualization(current_manager, image, opacity, border_width):
    """Re-render the overlay and area table with the current settings.

    Returns:
        Tuple ``(result_image, message, area_table)``; when no image is
        available the image slot is None and the table is empty.
    """
    if image is None:
        return None, "Please upload an image", pd.DataFrame()

    overlay = compose_all_layers(image, current_manager, opacity, border_width)
    annotated = draw_points_on_image(overlay, current_manager)

    width, height = image.size[0], image.size[1]
    area_rows = calculate_total_area_ratio(current_manager, width * height)
    return annotated, "Visualization updated", create_area_chart_data(area_rows)
# ============ GRADIO INTERFACE ============
# CSS rules handed to gr.Blocks(css=...) below; they target the
# "col-container" and "main-title" element ids used in the layout.
custom_css="""
#col-container { margin: 0 auto; max-width: 1200px; }
#main-title h1 { font-size: 2.1em !important; }
.layer-button { margin: 2px; }
"""
# No custom JavaScript needed anymore
custom_js = ""
# Global layer manager; used as the initial value of the gr.State component
layer_manager = LayerManager()
# Gradio UI: left column holds the upload, layer and point controls, right
# column holds the result image, area table and visualization settings.
# NOTE(review): the component nesting below is reconstructed from a copy that
# lost its indentation - confirm against the running layout.
with gr.Blocks(css=custom_css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# **SAM3 Layer Segmentation Tool**", elem_id="main-title")
        gr.Markdown("**Layer-based object separation and area analysis tool** | 1. Create layers 2. Select point mode and click 3. Run segmentation (processes all layers)")
        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image(type="pil", label="Upload Image", interactive=True, height=400)
                # Layer creation
                with gr.Row():
                    layer_name_input = gr.Textbox(label="Layer Name", placeholder="e.g. bench, tree, person")
                    create_layer_btn = gr.Button("Create", variant="primary")
                # Layer status display
                gr.Markdown("### Layers Status")
                layer_buttons_html = gr.HTML("<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>")
                # Layer selection
                layer_selector = gr.Dropdown(label="Select Layer to Add Points", choices=[], interactive=True)
                # Point mode selection
                gr.Markdown("### Point Mode")
                with gr.Row():
                    include_btn = gr.Button("Include Point", variant="primary", size="sm")
                    exclude_btn = gr.Button("Exclude Point", variant="secondary", size="sm")
                point_mode_text = gr.Textbox(label="Current Mode", value="Include Point (Red)", interactive=False)
                # Point instructions
                gr.Markdown("""
**Instructions:**
- Select a layer from dropdown
- Choose point mode (Include/Exclude)
- Click on image to add point
- **Red circle (โ)**: Include this area
- **Blue circle with X**: Exclude this area
""")
                # Controls
                with gr.Row():
                    segment_btn = gr.Button("Run All Segmentation", variant="primary", size="lg")
                    clear_btn = gr.Button("Clear Current Layer", variant="secondary")
                # Status
                status_text = gr.Textbox(label="Status", interactive=False)
                # State: the layer manager and the active point mode
                st_layer_manager = gr.State(layer_manager)
                point_mode_state = gr.State("positive") # "positive" or "negative"
            with gr.Column(scale=2):
                img_output = gr.Image(type="pil", label="Segmentation Result", height=400, interactive=False)
                # Area table
                area_table = gr.Dataframe(
                    label="Area Ratio by Layer"
                )
                # Settings
                with gr.Accordion("Visualization Settings", open=False):
                    opacity_slider = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Mask Opacity")
                    border_slider = gr.Slider(0, 5, value=2, step=1, label="Border Width")
    # Event wiring
    create_layer_btn.click(
        create_new_layer,
        inputs=[layer_name_input, st_layer_manager],
        outputs=[st_layer_manager, layer_buttons_html, layer_selector, status_text]
    )
    # Layer selection
    def on_layer_select(layer_id, mgr):
        """Make the chosen dropdown entry the current layer and refresh the status HTML."""
        if layer_id:
            mgr.set_current_layer(layer_id)
            return mgr, create_layer_status_html(mgr), f"Layer '{mgr.layers[layer_id]['name']}' selected"
        return mgr, create_layer_status_html(mgr), "Please select a layer"
    layer_selector.change(
        on_layer_select,
        inputs=[layer_selector, st_layer_manager],
        outputs=[st_layer_manager, layer_buttons_html, status_text]
    )
    # Point mode change
    def set_include_mode():
        """Return the state value and display text for include (positive) mode."""
        return "positive", "Include Point (Red)"
    def set_exclude_mode():
        """Return the state value and display text for exclude (negative) mode."""
        return "negative", "Exclude Point (Blue)"
    include_btn.click(
        set_include_mode,
        outputs=[point_mode_state, point_mode_text]
    )
    exclude_btn.click(
        set_exclude_mode,
        outputs=[point_mode_state, point_mode_text]
    )
    # Image click events - accept clicks on both img_input and img_output.
    # Both handlers pass img_input as the image argument.
    img_input.select(
        click_on_image,
        inputs=[st_layer_manager, img_input, point_mode_state],
        outputs=[img_output, st_layer_manager, layer_buttons_html, status_text]
    )
    img_output.select(
        click_on_image,
        inputs=[st_layer_manager, img_input, point_mode_state],
        outputs=[img_output, st_layer_manager, layer_buttons_html, status_text]
    )
    # Run segmentation for all layers
    segment_btn.click(
        segment_all_layers,
        inputs=[st_layer_manager, img_input, opacity_slider, border_slider],
        outputs=[img_output, st_layer_manager, layer_buttons_html, status_text, area_table]
    )
    clear_btn.click(
        clear_current_layer,
        inputs=[st_layer_manager, img_input, opacity_slider, border_slider],
        outputs=[img_output, st_layer_manager, layer_buttons_html, status_text, area_table]
    )
    # Live update when the opacity or border slider changes
    opacity_slider.change(
        refresh_visualization,
        inputs=[st_layer_manager, img_input, opacity_slider, border_slider],
        outputs=[img_output, status_text, area_table]
    )
    border_slider.change(
        refresh_visualization,
        inputs=[st_layer_manager, img_input, opacity_slider, border_slider],
        outputs=[img_output, status_text, area_table]
    )
    # Reset on image upload
    def on_image_upload(img):
        """Start over with a fresh LayerManager and default UI values for a new image."""
        new_manager = LayerManager()
        empty_html = "<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>"
        # Also show the uploaded image in the output
        return new_manager, img, pd.DataFrame(), empty_html, gr.Dropdown(choices=[], value=None), "positive", "Include Point (Red)", "New image uploaded"
    img_input.change(
        on_image_upload,
        inputs=[img_input],
        outputs=[st_layer_manager, img_output, area_table, layer_buttons_html, layer_selector, point_mode_state, point_mode_text, status_text]
    )
# Start the Gradio app only when this file is executed as a script;
# show_error=True is passed through to demo.launch.
if __name__ == "__main__":
    demo.launch(show_error=True)