import os import cv2 import tempfile import spaces import gradio as gr import numpy as np import torch import matplotlib import matplotlib.pyplot as plt import pandas as pd from PIL import Image, ImageDraw import inspect from transformers import ( Sam3Model, Sam3Processor, Sam3TrackerModel, Sam3TrackerProcessor ) # ============ GRADIO CLIENT COMPAT PATCH ============ # Some gradio_client versions crash when encountering boolean JSON Schema nodes # (e.g., additionalProperties: false/true) while generating /info API schema. # Patch defensively to prevent startup failures on HF Spaces. try: import gradio_client.utils as _grc_utils if not getattr(_grc_utils, "_BOOL_SCHEMA_PATCHED", False): # 1) get_type() sometimes assumes schema is a dict _orig_get_type = getattr(_grc_utils, "get_type", None) def _safe_get_type(schema): # noqa: ANN001 if isinstance(schema, bool) or not isinstance(schema, dict): return "any" if _orig_get_type is None: return "any" return _orig_get_type(schema) if _orig_get_type is not None: _grc_utils.get_type = _safe_get_type # 2) _json_schema_to_python_type() may raise on schema == True/False _orig_js2pt = getattr(_grc_utils, "_json_schema_to_python_type", None) def _safe_json_schema_to_python_type(schema, defs=None): # noqa: ANN001 if isinstance(schema, bool): # boolean JSON Schema: True means "any", False means "no value". # For API-info rendering, treating both as Any avoids crashes. return "Any" if _orig_js2pt is None: return "Any" return _orig_js2pt(schema, defs) if _orig_js2pt is not None: _grc_utils._json_schema_to_python_type = _safe_json_schema_to_python_type _grc_utils._BOOL_SCHEMA_PATCHED = True except Exception as _e: # If gradio_client is unavailable or API changes, ignore and proceed. print(f"[warn] gradio_client schema patch skipped: {_e}") # ============ GRADIO LAUNCH COMPAT ============ def _launch_compat(demo, **kwargs): """ Gradio 버전별로 Blocks.launch() 시그니처가 달라서(예: ssr 지원 여부), 현재 설치된 Gradio가 지원하는 키만 골라 launch()를 호출한다. """ # 먼저 원하는 옵션으로 그대로 호출하고, "unexpected keyword"가 나면 해당 키만 제거 후 재시도 try: return demo.launch(**kwargs) except TypeError as e: msg = str(e) if "unexpected keyword argument 'ssr'" in msg or 'unexpected keyword argument "ssr"' in msg: kwargs.pop("ssr", None) return demo.launch(**kwargs) # 그 외 케이스는 시그니처 기반으로 한 번 더 보수적으로 필터링 try: sig = inspect.signature(demo.launch) supported = set(sig.parameters.keys()) filtered = {k: v for k, v in kwargs.items() if k in supported} return demo.launch(**filtered) except Exception: raise # ============ THEME SETUP ============ # Theme disabled to avoid API schema issues on HF Spaces # ============ GLOBAL SETUP ============ device = "cuda" if torch.cuda.is_available() else "cpu" print(f"🖥️ Using compute device: {device}") # Models will be loaded lazily in functions to avoid build timeouts IMG_MODEL = None IMG_PROCESSOR = None TRK_MODEL = None TRK_PROCESSOR = None @spaces.GPU def load_models(): """Lazy load models when needed""" global IMG_MODEL, IMG_PROCESSOR, TRK_MODEL, TRK_PROCESSOR if IMG_MODEL is not None: return True print("⏳ Loading SAM3 Models...") try: # GPU가 사용 가능하면 GPU로 로드 device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float16 if device == "cuda" else torch.float32 IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3", device_map=device, torch_dtype=dtype) IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3") TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3", device_map=device, torch_dtype=dtype) TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3") print(f"✅ All Models loaded successfully on {device}!") return True except Exception as e: print(f"❌ Model loading failed: {e}") return False # ============ LAYER MANAGEMENT ============ class LayerManager: """레이어 기반 세그멘테이션 관리 클래스""" def __init__(self): self.layers = {} # layer_id -> {'name': str, 'color': tuple, 'points': list, 'point_labels': list, 'masks': list, 'area': float} self.current_layer_id = None self.layer_counter = 0 def create_layer(self, name, color=None): """새 레이어 생성""" if color is None: # 무작위 색상 생성 import random color = (random.randint(50, 200), random.randint(50, 200), random.randint(50, 200)) layer_id = f"layer_{self.layer_counter}" self.layers[layer_id] = { 'name': name, 'color': color, 'points': [], 'point_labels': [], # 1: positive, 0: negative 'masks': [], 'area': 0.0 } self.layer_counter += 1 return layer_id def add_point_to_layer(self, layer_id, point, label=1): """레이어에 포인트 추가""" if layer_id in self.layers: self.layers[layer_id]['points'].append(point) self.layers[layer_id]['point_labels'].append(label) print(f"[add_point_to_layer] Added point to '{self.layers[layer_id]['name']}': {point}, label={label}") print(f"[add_point_to_layer] Total points in '{self.layers[layer_id]['name']}': {len(self.layers[layer_id]['points'])}") def add_mask_to_layer(self, layer_id, mask): """레이어에 마스크 추가""" if layer_id in self.layers: # 기존 마스크를 교체 (같은 레이어에 재세그멘테이션 시) self.layers[layer_id]['masks'] = [mask] # 면적 계산 - mask를 numpy array로 변환 if isinstance(mask, torch.Tensor): mask_np = mask.cpu().numpy() else: mask_np = mask # 면적 계산 area = np.sum(mask_np > 0) self.layers[layer_id]['area'] = area # 디버깅: 마스크 정보 출력 print(f"[add_mask_to_layer] Layer: {self.layers[layer_id]['name']}, Mask shape: {mask_np.shape}, Area: {area}") def get_current_layer(self): """현재 선택된 레이어 반환""" if self.current_layer_id and self.current_layer_id in self.layers: return self.layers[self.current_layer_id] return None def set_current_layer(self, layer_id): """현재 레이어 설정""" self.current_layer_id = layer_id def clear_current_layer(self): """현재 레이어 초기화""" if self.current_layer_id and self.current_layer_id in self.layers: self.layers[self.current_layer_id]['points'] = [] self.layers[self.current_layer_id]['point_labels'] = [] self.layers[self.current_layer_id]['masks'] = [] self.layers[self.current_layer_id]['area'] = 0.0 def calculate_total_area_ratio(layer_manager, total_pixels): """전체 이미지 대비 각 레이어의 면적 비율 계산""" ratios = [] for layer_id, layer in layer_manager.layers.items(): area = layer['area'] ratio = (area / total_pixels) * 100 if total_pixels > 0 and area > 0 else 0 has_mask = len(layer['masks']) > 0 # 디버깅: 레이어 정보 출력 print(f"[calculate_total_area_ratio] Layer: {layer['name']}, Area: {area}, Ratio: {ratio}%, Masks: {len(layer['masks'])}, Has mask: {has_mask}") ratios.append({ 'layer_name': layer['name'], 'area_pixels': int(area), 'ratio_percent': round(ratio, 2) }) return ratios def create_area_chart_data(ratios): """면적 데이터를 테이블 포맷으로 변환""" if not ratios: return pd.DataFrame(columns=["Layer", "Area (pixels)", "Ratio(%)"]) data = [] for ratio in ratios: data.append({ "Layer": ratio['layer_name'], "Area (pixels)": f"{ratio['area_pixels']:,}", "Ratio(%)": f"{ratio['ratio_percent']}%" }) return pd.DataFrame(data) # ============ UTILITY FUNCTIONS ============ def compose_all_layers(base_image, layer_manager, opacity=0.5, border_width=2): """모든 레이어를 합성하여 최종 이미지 생성""" if isinstance(base_image, np.ndarray): base_image = Image.fromarray(base_image) base_image = base_image.convert("RGBA") if not layer_manager.layers: return base_image.convert("RGB") composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0)) for layer_id, layer in layer_manager.layers.items(): if not layer['masks']: continue layer_color = layer['color'] for mask in layer['masks']: if isinstance(mask, torch.Tensor): mask = mask.cpu().numpy() mask = mask.astype(np.uint8) if mask.ndim == 3: mask = mask[0] if mask.ndim == 2 and mask.shape[0] == 1: mask = mask[0] # 마스크를 PIL 이미지로 변환 mask_img = Image.fromarray((mask * 255).astype(np.uint8)) # 색상 레이어 생성 color_layer = Image.new("RGBA", base_image.size, layer_color + (0,)) mask_alpha = mask_img.point(lambda v: int(v * opacity * 255) if v > 0 else 0) color_layer.putalpha(mask_alpha) # 테두리 추가 if border_width > 0: try: # 마스크의 테두리 찾기 mask_np = np.array(mask_img) kernel_size = border_width * 2 + 1 dilated = cv2.dilate(mask_np, np.ones((kernel_size, kernel_size), np.uint8)) border = dilated - mask_np border_img = Image.fromarray(border) border_layer = Image.new("RGBA", base_image.size, (255, 255, 255, 255)) # 흰색 테두리 border_alpha = border_img.point(lambda v: 255 if v > 0 else 0) border_layer.putalpha(border_alpha) # 테두리를 먼저 합성 composite_layer = Image.alpha_composite(composite_layer, border_layer) except Exception as e: print(f"Border creation error: {e}") # 마스크 레이어 합성 composite_layer = Image.alpha_composite(composite_layer, color_layer) # 최종 합성 final_result = Image.alpha_composite(base_image, composite_layer) return final_result.convert("RGB") def draw_points_on_image(image, layer_manager): """이미지에 모든 레이어의 포인트들을 표시""" if isinstance(image, np.ndarray): image = Image.fromarray(image) draw_img = image.copy() draw = ImageDraw.Draw(draw_img) for layer_id, layer in layer_manager.layers.items(): is_current = (layer_id == layer_manager.current_layer_id) for i, point in enumerate(layer['points']): x, y = point label = layer['point_labels'][i] # 포지티브: 빨간색 원, 네거티브: 파란색 X표시 if label == 1: # Positive # 큰 빨간색 원 r = 15 if is_current else 10 draw.ellipse((x-r, y-r, x+r, y+r), fill="red", outline="white", width=3) # 작은 흰색 원 (중앙) draw.ellipse((x-3, y-3, x+3, y+3), fill="white") else: # Negative (0) # 큰 파란색 원 r = 15 if is_current else 10 draw.ellipse((x-r, y-r, x+r, y+r), fill="blue", outline="white", width=3) # X 표시 line_length = 8 draw.line([(x-line_length, y-line_length), (x+line_length, y+line_length)], fill="white", width=3) draw.line([(x-line_length, y+line_length), (x+line_length, y-line_length)], fill="white", width=3) return draw_img # ============ UI FUNCTIONS ============ def update_layer_selector_choices(manager): """레이어 선택 라디오 버튼의 choices 업데이트""" choices = [layer['name'] for layer in manager.layers.values()] current_value = None if manager.current_layer_id and manager.current_layer_id in manager.layers: current_value = manager.layers[manager.current_layer_id]['name'] # 컴포넌트를 새로 생성해서 반환하지 말고, gr.update()로 기존 컴포넌트를 업데이트해야 함 return gr.update(choices=choices, value=current_value, interactive=True) def create_new_layer(name, current_manager): """새 레이어 생성""" if current_manager is None: current_manager = LayerManager() if not name.strip(): return current_manager, create_layer_status_html(current_manager), update_layer_selector_choices(current_manager), "Please enter a layer name" # 중복 이름 체크 for layer_id, layer in current_manager.layers.items(): if layer['name'] == name.strip(): return current_manager, create_layer_status_html(current_manager), update_layer_selector_choices(current_manager), f"Layer name '{name}' already exists" layer_id = current_manager.create_layer(name.strip()) current_manager.set_current_layer(layer_id) return current_manager, create_layer_status_html(current_manager), update_layer_selector_choices(current_manager), f"Layer '{name}' created" def create_layer_status_html(current_manager): """레이어 상태 표시 HTML 생성 (시각적 표시만)""" if not current_manager.layers: return "