Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import tempfile | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from PIL import Image, ImageDraw | |
| import inspect | |
| from transformers import ( | |
| Sam3Model, Sam3Processor, | |
| Sam3TrackerModel, Sam3TrackerProcessor | |
| ) | |
# ============ GRADIO CLIENT COMPAT PATCH ============
# Some gradio_client versions crash when they meet boolean JSON Schema nodes
# (e.g. additionalProperties: false/true) while generating the /info API schema.
# The wrappers below are installed defensively, at most once, so that startup
# on HF Spaces does not fail.
try:
    import gradio_client.utils as _grc_utils

    if not getattr(_grc_utils, "_BOOL_SCHEMA_PATCHED", False):
        # 1) get_type() sometimes assumes the schema is a dict.
        _orig_get_type = getattr(_grc_utils, "get_type", None)

        def _safe_get_type(schema):  # noqa: ANN001
            """Return "any" for non-dict schemas, else defer to the original."""
            # A bool is never a dict, so this one check also covers True/False.
            if not isinstance(schema, dict):
                return "any"
            return "any" if _orig_get_type is None else _orig_get_type(schema)

        if _orig_get_type is not None:
            _grc_utils.get_type = _safe_get_type

        # 2) _json_schema_to_python_type() may raise when schema is True/False.
        _orig_js2pt = getattr(_grc_utils, "_json_schema_to_python_type", None)

        def _safe_json_schema_to_python_type(schema, defs=None):  # noqa: ANN001
            """Return "Any" for boolean schemas, else defer to the original."""
            # Boolean JSON Schema: True means "any", False means "no value".
            # For API-info rendering, treating both as Any avoids crashes.
            if isinstance(schema, bool) or _orig_js2pt is None:
                return "Any"
            return _orig_js2pt(schema, defs)

        if _orig_js2pt is not None:
            _grc_utils._json_schema_to_python_type = _safe_json_schema_to_python_type

        # Marker read by the getattr() guard above so the patch is applied once.
        _grc_utils._BOOL_SCHEMA_PATCHED = True
except Exception as _e:
    # If gradio_client is unavailable or its API changed, ignore and proceed.
    print(f"[warn] gradio_client schema patch skipped: {_e}")
| # ============ GRADIO LAUNCH COMPAT ============ | |
| def _launch_compat(demo, **kwargs): | |
| """ | |
| Gradio ๋ฒ์ ๋ณ๋ก Blocks.launch() ์๊ทธ๋์ฒ๊ฐ ๋ฌ๋ผ์(์: ssr ์ง์ ์ฌ๋ถ), | |
| ํ์ฌ ์ค์น๋ Gradio๊ฐ ์ง์ํ๋ ํค๋ง ๊ณจ๋ผ launch()๋ฅผ ํธ์ถํ๋ค. | |
| """ | |
| # ๋จผ์ ์ํ๋ ์ต์ ์ผ๋ก ๊ทธ๋๋ก ํธ์ถํ๊ณ , "unexpected keyword"๊ฐ ๋๋ฉด ํด๋น ํค๋ง ์ ๊ฑฐ ํ ์ฌ์๋ | |
| try: | |
| return demo.launch(**kwargs) | |
| except TypeError as e: | |
| msg = str(e) | |
| if "unexpected keyword argument 'ssr'" in msg or 'unexpected keyword argument "ssr"' in msg: | |
| kwargs.pop("ssr", None) | |
| return demo.launch(**kwargs) | |
| # ๊ทธ ์ธ ์ผ์ด์ค๋ ์๊ทธ๋์ฒ ๊ธฐ๋ฐ์ผ๋ก ํ ๋ฒ ๋ ๋ณด์์ ์ผ๋ก ํํฐ๋ง | |
| try: | |
| sig = inspect.signature(demo.launch) | |
| supported = set(sig.parameters.keys()) | |
| filtered = {k: v for k, v in kwargs.items() if k in supported} | |
| return demo.launch(**filtered) | |
| except Exception: | |
| raise | |
| # ============ THEME SETUP ============ | |
| # Theme disabled to avoid API schema issues on HF Spaces | |
| # ============ GLOBAL SETUP ============ | |
# Pick the compute device once at import time.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"๐ฅ๏ธ Using compute device: {device}")

# Models are loaded lazily inside functions to avoid build timeouts, so the
# module-level handles start out empty.
IMG_MODEL = IMG_PROCESSOR = TRK_MODEL = TRK_PROCESSOR = None
def load_models():
    """Lazily load the SAM3 image and tracker models and their processors.

    All four objects are loaded into locals first and published to the
    module-level handles only after every load succeeded, so a failure
    part-way through never leaves a half-initialised state behind.

    Returns:
        bool: True when all four handles are populated (already or after this
        call), False when loading raised.
    """
    global IMG_MODEL, IMG_PROCESSOR, TRK_MODEL, TRK_PROCESSOR

    # Check every handle, not just the first one loaded.
    handles = (IMG_MODEL, IMG_PROCESSOR, TRK_MODEL, TRK_PROCESSOR)
    if all(handle is not None for handle in handles):
        return True

    print("โณ Loading SAM3 Models...")
    try:
        # Load on the GPU when one is available (half precision there).
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if target_device == "cuda" else torch.float32

        img_model = Sam3Model.from_pretrained("DiffusionWave/sam3", device_map=target_device, torch_dtype=dtype)
        img_processor = Sam3Processor.from_pretrained("DiffusionWave/sam3")
        trk_model = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3", device_map=target_device, torch_dtype=dtype)
        trk_processor = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
    except Exception as e:
        print(f"โ Model loading failed: {e}")
        return False

    IMG_MODEL, IMG_PROCESSOR = img_model, img_processor
    TRK_MODEL, TRK_PROCESSOR = trk_model, trk_processor
    print(f"โ All Models loaded successfully on {target_device}!")
    return True
| # ============ LAYER MANAGEMENT ============ | |
class LayerManager:
    """Bookkeeping for named, layer-based segmentation state.

    Every layer lives in ``self.layers[layer_id]`` as a dict with the keys
    ``name``, ``color`` (RGB tuple), ``points``, ``point_labels``
    (1: positive, 0: negative), ``masks`` and ``area`` (pixel count).
    """

    def __init__(self):
        self.layers = {}  # layer_id -> layer dict (see class docstring)
        self.current_layer_id = None
        self.layer_counter = 0

    def create_layer(self, name, color=None):
        """Create a new empty layer and return its id (``layer_<n>``).

        When ``color`` is None a random RGB colour with each channel in
        50..200 is generated.
        """
        if color is None:
            import random

            color = tuple(random.randint(50, 200) for _ in range(3))

        layer_id = f"layer_{self.layer_counter}"
        self.layers[layer_id] = {
            'name': name,
            'color': color,
            'points': [],
            'point_labels': [],  # 1: positive, 0: negative
            'masks': [],
            'area': 0.0,
        }
        self.layer_counter += 1
        return layer_id

    def add_point_to_layer(self, layer_id, point, label=1):
        """Append a prompt point and its label; unknown layer ids are ignored."""
        layer = self.layers.get(layer_id)
        if layer is None:
            return
        layer['points'].append(point)
        layer['point_labels'].append(label)
        print(f"[add_point_to_layer] Added point to '{layer['name']}': {point}, label={label}")
        print(f"[add_point_to_layer] Total points in '{layer['name']}': {len(layer['points'])}")

    def add_mask_to_layer(self, layer_id, mask):
        """Store ``mask`` as the layer's only mask and recompute its area.

        Unknown layer ids are ignored. The area is the number of mask
        elements greater than zero, stored as a plain Python int.
        """
        layer = self.layers.get(layer_id)
        if layer is None:
            return

        # Re-segmenting a layer replaces its previous mask.
        layer['masks'] = [mask]

        if isinstance(mask, torch.Tensor):
            # detach() keeps .numpy() valid even for tensors tracking gradients.
            mask_np = mask.detach().cpu().numpy()
        else:
            mask_np = np.asarray(mask)

        area = int(np.count_nonzero(mask_np > 0))
        layer['area'] = area
        # Debug output: mask information.
        print(f"[add_mask_to_layer] Layer: {layer['name']}, Mask shape: {mask_np.shape}, Area: {area}")

    def get_current_layer(self):
        """Return the currently selected layer dict, or None."""
        if self.current_layer_id and self.current_layer_id in self.layers:
            return self.layers[self.current_layer_id]
        return None

    def set_current_layer(self, layer_id):
        """Mark ``layer_id`` as the currently selected layer."""
        self.current_layer_id = layer_id

    def clear_current_layer(self):
        """Reset points, labels, masks and area of the selected layer."""
        layer = self.get_current_layer()
        if layer is None:
            return
        layer['points'] = []
        layer['point_labels'] = []
        layer['masks'] = []
        layer['area'] = 0.0
def calculate_total_area_ratio(layer_manager, total_pixels):
    """Compute each layer's area as a percentage of the whole image.

    Returns a list with one dict per layer (in layer order) holding
    ``layer_name``, ``area_pixels`` and ``ratio_percent`` (two decimals).
    The ratio is 0 when either the image or the layer has no pixels.
    """

    def _summarize(layer):
        pixel_area = layer['area']
        if total_pixels > 0 and pixel_area > 0:
            percent = (pixel_area / total_pixels) * 100
        else:
            percent = 0
        mask_count = len(layer['masks'])
        # Debug output: per-layer information.
        print(f"[calculate_total_area_ratio] Layer: {layer['name']}, Area: {pixel_area}, Ratio: {percent}%, Masks: {mask_count}, Has mask: {mask_count > 0}")
        return {
            'layer_name': layer['name'],
            'area_pixels': int(pixel_area),
            'ratio_percent': round(percent, 2),
        }

    return [_summarize(layer) for layer in layer_manager.layers.values()]
def create_area_chart_data(ratios):
    """Turn the per-layer area ratios into a display-ready DataFrame.

    Pixel counts are rendered with thousands separators and ratios with a
    trailing percent sign. An empty input yields an empty frame that still
    carries the three column headers.
    """
    if not ratios:
        return pd.DataFrame(columns=["Layer", "Area (pixels)", "Ratio(%)"])

    rows = [
        {
            "Layer": entry['layer_name'],
            "Area (pixels)": f"{entry['area_pixels']:,}",
            "Ratio(%)": f"{entry['ratio_percent']}%",
        }
        for entry in ratios
    ]
    return pd.DataFrame(rows)
| # ============ UTILITY FUNCTIONS ============ | |
def compose_all_layers(base_image, layer_manager, opacity=0.5, border_width=2):
    """Blend every layer's mask onto the image and return an RGB picture.

    Each mask is drawn as a fill in the layer's colour at ``opacity`` and,
    when ``border_width`` is positive, a white outline of that width around
    the mask.

    Args:
        base_image: PIL image or numpy array to draw on.
        layer_manager: Object whose ``layers`` dict holds ``color`` and
            ``masks`` per layer.
        opacity: Fill opacity; values outside 0..1 are clamped.
        border_width: Outline width in pixels; 0 disables the outline.

    Returns:
        A new PIL image in RGB mode.
    """
    if isinstance(base_image, np.ndarray):
        base_image = Image.fromarray(base_image)
    base_image = base_image.convert("RGBA")

    if not layer_manager.layers:
        return base_image.convert("RGB")

    # The alpha of a covered pixel is simply opacity scaled to 0..255. The
    # previous lookup multiplied the (already 0/255) mask value by
    # opacity * 255, giving values far outside the 8-bit range, so the
    # requested opacity was not what ended up in the alpha channel.
    fill_alpha = int(round(min(max(float(opacity), 0.0), 1.0) * 255))
    # Sliders may deliver floats; the dilation kernel needs an int size.
    border_px = max(int(border_width), 0)

    overlay = Image.new("RGBA", base_image.size, (0, 0, 0, 0))

    for layer in layer_manager.layers.values():
        # tuple() tolerates colours that arrive as lists.
        fill_color = tuple(layer['color']) + (0,)

        for mask in layer['masks']:
            if isinstance(mask, torch.Tensor):
                mask = mask.detach().cpu().numpy()
            mask = np.asarray(mask)
            # Strip leading singleton/batch axes until a 2-D mask remains.
            while mask.ndim > 2:
                mask = mask[0]

            covered = mask > 0
            mask_u8 = covered.astype(np.uint8) * 255

            # Colour fill whose alpha channel follows the mask.
            color_layer = Image.new("RGBA", base_image.size, fill_color)
            alpha = np.where(covered, fill_alpha, 0).astype(np.uint8)
            color_layer.putalpha(Image.fromarray(alpha))

            if border_px > 0:
                # Best effort: a failing outline must not lose the fill.
                try:
                    size = border_px * 2 + 1
                    dilated = cv2.dilate(mask_u8, np.ones((size, size), np.uint8))
                    # Dilation only adds pixels, so the difference is the ring
                    # just outside the mask (values 0 or 255).
                    outline = dilated - mask_u8
                    border_layer = Image.new("RGBA", base_image.size, (255, 255, 255, 255))  # white border
                    border_layer.putalpha(Image.fromarray(outline))
                    # The outline goes underneath the fill.
                    overlay = Image.alpha_composite(overlay, border_layer)
                except Exception as e:
                    print(f"Border creation error: {e}")

            overlay = Image.alpha_composite(overlay, color_layer)

    return Image.alpha_composite(base_image, overlay).convert("RGB")
def draw_points_on_image(image, layer_manager):
    """Return a copy of the image with every layer's prompt points drawn.

    Positive points (label 1) are red discs with a white centre dot; all
    other points are blue discs with a white X. Points of the currently
    selected layer are drawn larger (radius 15 instead of 10).
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)
    arm = 8  # half-length of the X strokes

    for layer_id, layer in layer_manager.layers.items():
        radius = 15 if layer_id == layer_manager.current_layer_id else 10

        for idx, (x, y) in enumerate(layer['points']):
            positive = layer['point_labels'][idx] == 1
            disc = (x - radius, y - radius, x + radius, y + radius)
            pen.ellipse(disc, fill="red" if positive else "blue", outline="white", width=3)

            if positive:
                # Small white dot in the centre.
                pen.ellipse((x - 3, y - 3, x + 3, y + 3), fill="white")
            else:
                # White X across the disc.
                pen.line([(x - arm, y - arm), (x + arm, y + arm)], fill="white", width=3)
                pen.line([(x - arm, y + arm), (x + arm, y - arm)], fill="white", width=3)

    return canvas
| # ============ UI FUNCTIONS ============ | |
def update_layer_selector_choices(manager):
    """Build the update for the layer-selection radio group.

    The choices are all layer names; the selected value is the name of the
    current layer, or None when no valid layer is selected.
    """
    active_id = manager.current_layer_id
    if active_id and active_id in manager.layers:
        selected = manager.layers[active_id]['name']
    else:
        selected = None

    names = [layer['name'] for layer in manager.layers.values()]
    # Update the existing component via gr.update() rather than returning a
    # newly constructed component.
    return gr.update(choices=names, value=selected, interactive=True)
def create_new_layer(name, current_manager):
    """Create a layer called ``name`` and make it the current layer.

    The name is stripped of surrounding whitespace; a missing/blank name or a
    name that already exists is rejected with a status message.

    Returns:
        tuple: (manager, layer status HTML, radio update, status message).
    """
    if current_manager is None:
        current_manager = LayerManager()

    # A cleared textbox may deliver None; treat it like an empty name.
    clean_name = (name or "").strip()

    def _response(message):
        # Built at return time so it reflects the manager's final state.
        return (
            current_manager,
            create_layer_status_html(current_manager),
            update_layer_selector_choices(current_manager),
            message,
        )

    if not clean_name:
        return _response("Please enter a layer name")

    # Reject duplicate names.
    if any(layer['name'] == clean_name for layer in current_manager.layers.values()):
        return _response(f"Layer name '{clean_name}' already exists")

    current_manager.set_current_layer(current_manager.create_layer(clean_name))
    # Report the name exactly as stored (stripped).
    return _response(f"Layer '{clean_name}' created")
def create_layer_status_html(current_manager):
    """Render the layer overview as an HTML snippet (display only).

    Each layer becomes a card tinted with the layer colour showing its name,
    the number of positive/negative points and the number of masks; the
    current layer's card is highlighted. Layer names are HTML-escaped
    because they come straight from user input.

    Returns:
        str: HTML for all cards, or a placeholder when there are no layers.
    """
    from html import escape

    if not current_manager.layers:
        return "<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>"

    cards = []
    for layer_id, layer in current_manager.layers.items():
        # Layer colour as #rrggbb.
        r, g, b = layer['color']
        color_hex = f"#{r:02x}{g:02x}{b:02x}"

        # Card style depends on whether this is the active layer.
        if current_manager.current_layer_id == layer_id:
            style_lines = [
                f"background: linear-gradient(135deg, {color_hex}, {color_hex}dd);",
                "color: white;",
                "border: 3px solid #4682B4;",
                "box-shadow: 0 4px 12px rgba(70, 130, 180, 0.4);",
            ]
        else:
            style_lines = [
                f"background: linear-gradient(135deg, {color_hex}aa, {color_hex}77);",
                "color: white;",
                f"border: 2px solid {color_hex};",
                "opacity: 0.7;",
            ]
        style = "\n".join(["", *style_lines, ""])

        # Point counts, split into positive and negative.
        labels = layer['point_labels']
        positive_points = sum(1 for label in labels if label == 1)
        negative_points = sum(1 for label in labels if label == 0)
        masks_count = len(layer['masks'])

        status_icon = "[OK]" if masks_count > 0 else "[ ]"
        # Escape so markup characters in a name cannot inject HTML.
        safe_name = escape(str(layer['name']))

        cards.append("\n".join([
            "",
            f'<div style="{style}',
            "padding: 12px 20px;",
            "border-radius: 8px;",
            "font-weight: 600;",
            "font-size: 14px;",
            'min-width: 150px;">',
            f"{status_icon} {safe_name}<br>",
            "<small style='font-size: 11px; opacity: 0.9;'>",
            f"<span style='color: #ffcccc;'>+{positive_points}</span>",
            f"<span style='color: #ccccff;'>-{negative_points}</span>",
            f"{masks_count}mask",
            "</small>",
            "</div>",
            "",
        ]))

    return (
        "<div style='display: flex; flex-wrap: wrap; gap: 8px; padding: 10px;'>"
        + "".join(cards)
        + "</div>"
    )
def click_on_image(current_manager, image, point_mode, evt: gr.SelectData):
    """Handle an image click by adding an include/exclude point.

    The point is added to the current layer with label 1 when ``point_mode``
    is "positive" and label 0 otherwise, then the source image is redrawn
    with all points.

    Returns:
        tuple: (image with points, manager, layer status HTML, status text).
    """
    if current_manager is None:
        current_manager = LayerManager()

    active_id = current_manager.current_layer_id
    if image is None or active_id is None:
        return image, current_manager, create_layer_status_html(current_manager), "Please select image and layer"

    x, y = evt.index
    # Label follows the point mode (positive=1, negative=0).
    label = 1 if point_mode == "positive" else 0
    layer_name = current_manager.layers[active_id]['name']

    print(
        f"\n[click_on_image] ================",
        f"[click_on_image] Layer: {layer_name}",
        f"[click_on_image] Point mode: {point_mode}, Label: {label}, Position: ({x}, {y})",
        sep="\n",
    )
    current_manager.add_point_to_layer(active_id, [x, y], label)

    # Points are drawn on the original image.
    annotated = draw_points_on_image(image, current_manager)

    if label == 1:
        mode_text = "Include"
    else:
        mode_text = "Exclude"
    status = f"{mode_text} point added to '{layer_name}' at ({x}, {y})"
    return annotated, current_manager, create_layer_status_html(current_manager), status
def segment_all_layers(current_manager, image):
    """Segment every layer that has prompt points, in layer order.

    Layers without points are skipped. The models are loaded at most once
    per call; when loading fails no layer is segmented and the returned
    status says so instead of claiming success.

    Returns:
        tuple: (result image or None, manager, layer status HTML,
        status message, area table DataFrame).
    """
    if current_manager is None:
        current_manager = LayerManager()
    if image is None:
        return None, current_manager, create_layer_status_html(current_manager), "Please upload an image", pd.DataFrame()
    if not current_manager.layers:
        return None, current_manager, create_layer_status_html(current_manager), "Please create layers first", pd.DataFrame()

    try:
        print(f"\n[segment_all_layers] Starting segmentation for all layers...")
        segmented_count = 0
        skipped_count = 0
        failed_count = 0
        models_ready = None  # None until the first layer that needs the models

        for layer_id, layer in current_manager.layers.items():
            layer_name = layer['name']

            # Layers without points are skipped.
            if not layer['points']:
                print(f"[segment_all_layers] Skipping '{layer_name}' - no points")
                skipped_count += 1
                continue

            print(f"\n[segment_all_layers] Processing layer: {layer_name}")
            print(f"[segment_all_layers] Points: {len(layer['points'])}, Labels: {layer['point_labels']}")

            # Load the models once, on the first layer that needs them.
            if models_ready is None:
                models_ready = load_models()
            if not models_ready:
                print(f"[segment_all_layers] Failed to load models for layer: {layer_name}")
                failed_count += 1
                continue

            current_manager.add_mask_to_layer(layer_id, _segment_layer(image, layer))
            segmented_count += 1
            print(f"[segment_all_layers] Completed '{layer_name}'")

        # Result image (with points), using the default opacity and border.
        result_image = compose_all_layers(image, current_manager, 0.5, 2)
        result_image = draw_points_on_image(result_image, current_manager)

        # Area analysis.
        total_pixels = image.size[0] * image.size[1]
        ratios = calculate_total_area_ratio(current_manager, total_pixels)
        chart_data = create_area_chart_data(ratios)

        if failed_count:
            # Do not report success when the models never loaded.
            status_msg = f"Error: SAM3 models could not be loaded; {failed_count} layers were not segmented"
        else:
            status_msg = f"Segmentation completed! Processed: {segmented_count} layers, Skipped: {skipped_count} layers"
        print(f"\n[segment_all_layers] {status_msg}")
        return result_image, current_manager, create_layer_status_html(current_manager), status_msg, chart_data
    except Exception as e:
        import traceback
        print(f"[segment_all_layers] Error: {str(e)}")
        traceback.print_exc()
        return None, current_manager, create_layer_status_html(current_manager), f"Error: {str(e)}", pd.DataFrame()


def _segment_layer(image, layer):
    """Run the SAM3 tracker on one layer's points and return its first mask."""
    input_points = [[layer['points']]]
    input_labels = [[layer['point_labels']]]

    # Put the inputs on the same device as the model.
    model_device = next(TRK_MODEL.parameters()).device
    inputs = TRK_PROCESSOR(images=image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(model_device)

    with torch.no_grad():
        outputs = TRK_MODEL(**inputs, multimask_output=False)

    masks = TRK_PROCESSOR.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True)[0]
    return masks[0]
def clear_current_layer(current_manager, image):
    """Reset the selected layer and redraw the remaining layers.

    Returns:
        tuple: (result image or None, manager, layer status HTML,
        status message, area table DataFrame).
    """
    if current_manager is None:
        current_manager = LayerManager()

    # Nothing selected: nothing to clear.
    if not current_manager.current_layer_id:
        return None, current_manager, create_layer_status_html(current_manager), "Please select a layer", pd.DataFrame()

    current_manager.clear_current_layer()

    if image:
        # Default opacity and border width.
        preview = draw_points_on_image(
            compose_all_layers(image, current_manager, 0.5, 2),
            current_manager,
        )
        total_pixels = image.size[0] * image.size[1]
    else:
        preview = None
        total_pixels = 0

    table = create_area_chart_data(
        calculate_total_area_ratio(current_manager, total_pixels)
    )
    return preview, current_manager, create_layer_status_html(current_manager), "Layer cleared", table
def refresh_visualization(current_manager, image, opacity, border_width):
    """Redraw the overlay with the given opacity and border width.

    Returns:
        tuple: (result image or None, status message, area table DataFrame).
    """
    manager = LayerManager() if current_manager is None else current_manager

    if image is None:
        return None, "Please upload an image", pd.DataFrame()

    overlay = compose_all_layers(image, manager, opacity, border_width)
    rendered = draw_points_on_image(overlay, manager)

    pixel_count = image.size[0] * image.size[1]
    table = create_area_chart_data(calculate_total_area_ratio(manager, pixel_count))
    return rendered, "Visualization updated", table
| # ============ GRADIO INTERFACE ============ | |
| # No custom JavaScript needed anymore | |
custom_js = ""

with gr.Blocks() as demo:
    # ---- Layout ----
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# **SAM3 Layer Segmentation Tool**", elem_id="main-title")
        gr.Markdown("**Layer-based object separation and area analysis tool** | 1. Create layers 2. Select point mode and click 3. Run segmentation (processes all layers)")

        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image(type="pil", label="Upload Image", interactive=True, height=400)

                # Layer creation.
                with gr.Row():
                    layer_name_input = gr.Textbox(label="Layer Name", placeholder="e.g. bench, tree, person")
                    create_layer_btn = gr.Button("Create", variant="primary")

                # Layer status display.
                gr.Markdown("### Layers Status")
                layer_buttons_html = gr.HTML("<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>")

                # Layer selection (radio buttons).
                layer_selector = gr.Radio(label="Select Layer to Add Points", choices=[], interactive=True, value=None)

                # Point mode selection.
                gr.Markdown("### Point Mode")
                with gr.Row():
                    include_btn = gr.Button("Include Point", variant="primary", size="sm")
                    exclude_btn = gr.Button("Exclude Point", variant="secondary", size="sm")
                point_mode_text = gr.Textbox(label="Current Mode", value="Include Point (Red)", interactive=False)

                # Usage instructions.
                gr.Markdown(
                    "\n"
                    "**Instructions:**\n"
                    "- Select a layer from dropdown\n"
                    "- Choose point mode (Include/Exclude)\n"
                    "- Click on image to add point\n"
                    "- **Red circle (โ)**: Include this area\n"
                    "- **Blue circle with X**: Exclude this area\n"
                )

                # Controls.
                with gr.Row():
                    segment_btn = gr.Button("Run All Segmentation", variant="primary", size="lg")
                    clear_btn = gr.Button("Clear Current Layer", variant="secondary")

                # Status.
                status_text = gr.Textbox(label="Status", interactive=False)
                st_layer_manager = gr.State(None)  # the LayerManager is created inside the handlers
                point_mode_state = gr.State("positive")  # "positive" or "negative"

            with gr.Column(scale=2):
                img_output = gr.Image(type="pil", label="Segmentation Result", height=400, interactive=False)

                # Area table.
                area_table = gr.Dataframe(
                    label="Area Ratio by Layer"
                )

                # Visualization settings were removed for now to avoid API
                # schema conflicts.

    # ---- Event wiring (registration order is kept as-is) ----
    create_layer_btn.click(
        create_new_layer,
        inputs=[layer_name_input, st_layer_manager],
        outputs=[st_layer_manager, layer_buttons_html, layer_selector, status_text],
    )

    # Layer selection.
    def on_layer_select(selected_name, mgr):
        """Make the layer with the selected name the current layer."""
        if mgr is None:
            mgr = LayerManager()
        if selected_name:
            # Look the layer id up by name (first match wins).
            layer_id = next(
                (lid for lid, layer in mgr.layers.items() if layer['name'] == selected_name),
                None,
            )
            if layer_id:
                mgr.set_current_layer(layer_id)
                return mgr, create_layer_status_html(mgr), f"Layer '{selected_name}' selected"
        return mgr, create_layer_status_html(mgr), "Please select a layer"

    layer_selector.change(
        on_layer_select,
        inputs=[layer_selector, st_layer_manager],
        outputs=[st_layer_manager, layer_buttons_html, status_text],
    )

    # Point mode switching.
    def set_include_mode():
        """Switch to include (positive) points."""
        return "positive", "Include Point (Red)"

    def set_exclude_mode():
        """Switch to exclude (negative) points."""
        return "negative", "Exclude Point (Blue)"

    mode_outputs = [point_mode_state, point_mode_text]
    include_btn.click(set_include_mode, outputs=mode_outputs)
    exclude_btn.click(set_exclude_mode, outputs=mode_outputs)

    # Image clicks are accepted on both the input and the output image; in
    # both cases the points are drawn on the uploaded input image.
    for clickable in (img_input, img_output):
        clickable.select(
            click_on_image,
            inputs=[st_layer_manager, img_input, point_mode_state],
            outputs=[img_output, st_layer_manager, layer_buttons_html, status_text],
        )

    # Run segmentation for all layers / clear the current layer.
    result_outputs = [img_output, st_layer_manager, layer_buttons_html, status_text, area_table]
    segment_btn.click(
        segment_all_layers,
        inputs=[st_layer_manager, img_input],
        outputs=result_outputs,
    )
    clear_btn.click(
        clear_current_layer,
        inputs=[st_layer_manager, img_input],
        outputs=result_outputs,
    )

    # Live updates from the opacity/border sliders are temporarily disabled.
    # When the sliders return, wire both of their .change events to
    # refresh_visualization with
    #   inputs=[st_layer_manager, img_input, opacity_slider, border_slider]
    #   outputs=[img_output, status_text, area_table]

    # Reset everything when a new image is uploaded.
    def on_image_upload(img):
        """Start over with a fresh manager and show the uploaded image."""
        new_manager = LayerManager()
        return (
            new_manager,
            img,  # the uploaded image is shown in the output as well
            pd.DataFrame(),
            create_layer_status_html(new_manager),  # "No layers created" placeholder
            update_layer_selector_choices(new_manager),
            "positive",
            "Include Point (Red)",
            "New image uploaded",
        )

    img_input.change(
        on_image_upload,
        inputs=[img_input],
        outputs=[st_layer_manager, img_output, area_table, layer_buttons_html, layer_selector, point_mode_state, point_mode_text, status_text],
    )
def run_spaces():
    """Launch the app with the options used on Hugging Face Spaces."""
    launch_options = {"ssr": False}
    _launch_compat(demo, **launch_options)
def run_local():
    """Launch the app for local use, with error display enabled."""
    launch_options = {"show_error": True, "ssr": False}
    _launch_compat(demo, **launch_options)
if __name__ == "__main__":
    import os

    # SPACE_ID in the environment is used as the marker for running on
    # Hugging Face Spaces; anything else is treated as a local run.
    entry_point = run_spaces if os.getenv("SPACE_ID") else run_local
    entry_point()