"""SAM3 layer-based segmentation demo (Gradio app).

Workflow exposed by the UI:

1. Create one or more named layers.
2. Pick a point mode (include / exclude) and click on the image to add
   prompt points to the currently selected layer.
3. Run segmentation: every layer that has points is segmented with the SAM3
   tracker model, the masks are overlaid on the image and a per-layer area
   table is produced.
"""
import html
import inspect
import os
import random
import tempfile

import cv2
# NOTE: `spaces` is deliberately imported before `torch`, as in the original
# import order of this file.
import spaces
import gradio as gr
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw
from transformers import (
    Sam3Model,
    Sam3Processor,
    Sam3TrackerModel,
    Sam3TrackerProcessor,
)

# ============ GRADIO CLIENT COMPAT PATCH ============
# Some gradio_client versions crash when encountering boolean JSON Schema nodes
# (e.g., additionalProperties: false/true) while generating /info API schema.
# Patch defensively to prevent startup failures on HF Spaces.
try:
    import gradio_client.utils as _grc_utils

    if not getattr(_grc_utils, "_BOOL_SCHEMA_PATCHED", False):
        # 1) get_type() sometimes assumes schema is a dict
        _orig_get_type = getattr(_grc_utils, "get_type", None)
        if _orig_get_type is not None:

            def _safe_get_type(schema):  # noqa: ANN001
                """Return "any" for non-dict schema nodes, else defer to gradio_client."""
                if not isinstance(schema, dict):
                    return "any"
                return _orig_get_type(schema)

            _grc_utils.get_type = _safe_get_type

        # 2) _json_schema_to_python_type() may raise on schema == True/False
        _orig_js2pt = getattr(_grc_utils, "_json_schema_to_python_type", None)
        if _orig_js2pt is not None:

            def _safe_json_schema_to_python_type(schema, defs=None):  # noqa: ANN001
                """Render boolean JSON Schema nodes as "Any" instead of crashing."""
                if isinstance(schema, bool):
                    # boolean JSON Schema: True means "any", False means "no value".
                    # For API-info rendering, treating both as Any avoids crashes.
                    return "Any"
                return _orig_js2pt(schema, defs)

            _grc_utils._json_schema_to_python_type = _safe_json_schema_to_python_type

        _grc_utils._BOOL_SCHEMA_PATCHED = True
except Exception as _e:
    # If gradio_client is unavailable or its API changed, ignore and proceed.
    print(f"[warn] gradio_client schema patch skipped: {_e}")


# ============ GRADIO LAUNCH COMPAT ============
def _launch_compat(demo, **kwargs):
    """Call ``demo.launch()`` with only the keyword arguments it supports.

    The ``launch()`` signature differs between Gradio versions, so the
    requested options are filtered against the installed signature before the
    call. The legacy ``ssr`` alias used by the callers in this file is mapped
    to ``ssr_mode`` when the installed ``launch()`` exposes that parameter and
    no explicit ``ssr_mode`` was given.

    Args:
        demo: Object exposing a ``launch(**kwargs)`` method (a ``gr.Blocks``).
        **kwargs: Desired launch options.

    Returns:
        Whatever ``demo.launch()`` returns.
    """
    try:
        params = inspect.signature(demo.launch).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: fall back to the retry logic below.
        params = None

    if params is not None:
        if "ssr" in kwargs and "ssr" not in params and "ssr_mode" in params:
            kwargs.setdefault("ssr_mode", kwargs.pop("ssr"))
        accepts_any_keyword = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if not accepts_any_keyword:
            kwargs = {k: v for k, v in kwargs.items() if k in params}

    try:
        return demo.launch(**kwargs)
    except TypeError as exc:
        msg = str(exc)
        ssr_rejected = (
            "unexpected keyword argument 'ssr'" in msg
            or 'unexpected keyword argument "ssr"' in msg
        )
        if not ssr_rejected or "ssr" not in kwargs:
            raise
        # Only reachable when the signature could not filter `ssr` up front.
        kwargs.pop("ssr")
        return demo.launch(**kwargs)


# ============ THEME SETUP ============
# Theme disabled to avoid API schema issues on HF Spaces

# ============ GLOBAL SETUP ============
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using compute device: {device}")

MODEL_ID = "DiffusionWave/sam3"

# Models are loaded lazily inside functions to avoid build timeouts.
IMG_MODEL = None
IMG_PROCESSOR = None
TRK_MODEL = None
TRK_PROCESSOR = None

# NOTE(review): the HTML markup in this file was reconstructed; only the text
# content and the inline style declarations of the original were recoverable.
# Confirm the exact tags/attributes against the deployed app.
EMPTY_LAYERS_HTML = "<div>No layers created</div>"

INSTRUCTIONS_MD = """
**Instructions:**
- Select a layer from dropdown
- Choose point mode (Include/Exclude)
- Click on image to add point
- **Red circle (●)**: Include this area
- **Blue circle with X**: Exclude this area
"""


# NOTE(review): if functions decorated with `spaces.GPU` run in a separate
# worker process on ZeroGPU hardware, the globals assigned here may not be
# visible to `segment_all_layers` (which is not decorated) — TODO confirm.
@spaces.GPU
def load_models():
    """Lazily load the SAM3 image and tracker models and their processors.

    The module-level globals are only assigned after *all* four objects were
    created, so a partial failure can never leave the cache in a state where
    the next call reports success with a missing model.

    Returns:
        bool: True when all models are available, False if loading failed.
    """
    global IMG_MODEL, IMG_PROCESSOR, TRK_MODEL, TRK_PROCESSOR

    cached = (IMG_MODEL, IMG_PROCESSOR, TRK_MODEL, TRK_PROCESSOR)
    if all(obj is not None for obj in cached):
        return True

    print("⏳ Loading SAM3 Models...")
    try:
        # Load onto the GPU (half precision) when one is available.
        target = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if target == "cuda" else torch.float32
        img_model = Sam3Model.from_pretrained(MODEL_ID, device_map=target, torch_dtype=dtype)
        img_processor = Sam3Processor.from_pretrained(MODEL_ID)
        trk_model = Sam3TrackerModel.from_pretrained(MODEL_ID, device_map=target, torch_dtype=dtype)
        trk_processor = Sam3TrackerProcessor.from_pretrained(MODEL_ID)
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        return False

    IMG_MODEL, IMG_PROCESSOR = img_model, img_processor
    TRK_MODEL, TRK_PROCESSOR = trk_model, trk_processor
    print(f"✅ All Models loaded successfully on {target}!")
    return True


# ============ LAYER MANAGEMENT ============
class LayerManager:
    """Bookkeeping for layer-based segmentation.

    ``layers`` maps ``layer_id`` to a dict with the keys ``name`` (str),
    ``color`` (RGB tuple), ``points`` (list of ``[x, y]``), ``point_labels``
    (list of 1 = positive / 0 = negative), ``masks`` (list) and ``area``
    (number of mask pixels).
    """

    def __init__(self):
        self.layers = {}
        self.current_layer_id = None
        self.layer_counter = 0

    def create_layer(self, name, color=None):
        """Create a new layer and return its id.

        Args:
            name: Display name of the layer.
            color: Optional RGB tuple; a random mid-range color is picked
                when omitted.
        """
        if color is None:
            color = tuple(random.randint(50, 200) for _ in range(3))

        layer_id = f"layer_{self.layer_counter}"
        self.layers[layer_id] = {
            'name': name,
            'color': color,
            'points': [],
            'point_labels': [],  # 1: positive, 0: negative
            'masks': [],
            'area': 0.0,
        }
        self.layer_counter += 1
        return layer_id

    def add_point_to_layer(self, layer_id, point, label=1):
        """Append a prompt point (and its label) to a layer; unknown ids are ignored."""
        layer = self.layers.get(layer_id)
        if layer is None:
            return
        layer['points'].append(point)
        layer['point_labels'].append(label)
        print(f"[add_point_to_layer] Added point to '{layer['name']}': {point}, label={label}")
        print(f"[add_point_to_layer] Total points in '{layer['name']}': {len(layer['points'])}")

    def add_mask_to_layer(self, layer_id, mask):
        """Store ``mask`` as the layer's only mask and update its pixel area.

        Any previous mask is replaced, so re-segmenting a layer does not
        accumulate masks. Unknown layer ids are ignored.
        """
        layer = self.layers.get(layer_id)
        if layer is None:
            return
        layer['masks'] = [mask]

        # The area is the number of positive pixels, computed on a numpy view.
        mask_np = mask.cpu().numpy() if isinstance(mask, torch.Tensor) else mask
        area = int(np.count_nonzero(np.asarray(mask_np) > 0))
        layer['area'] = area
        print(f"[add_mask_to_layer] Layer: {layer['name']}, Mask shape: {mask_np.shape}, Area: {area}")

    def get_current_layer(self):
        """Return the currently selected layer dict, or None."""
        if self.current_layer_id and self.current_layer_id in self.layers:
            return self.layers[self.current_layer_id]
        return None

    def set_current_layer(self, layer_id):
        """Select ``layer_id`` as the current layer."""
        self.current_layer_id = layer_id

    def clear_current_layer(self):
        """Remove all points and masks from the current layer (if any)."""
        layer = self.get_current_layer()
        if layer is None:
            return
        layer['points'] = []
        layer['point_labels'] = []
        layer['masks'] = []
        layer['area'] = 0.0


def calculate_total_area_ratio(layer_manager, total_pixels):
    """Compute each layer's mask area as a percentage of the whole image.

    Args:
        layer_manager: The ``LayerManager`` holding the layers.
        total_pixels: Pixel count of the image; 0 yields a ratio of 0.

    Returns:
        list[dict]: One dict per layer with ``layer_name``, ``area_pixels``
        and ``ratio_percent`` (rounded to two decimals).
    """
    ratios = []
    for layer in layer_manager.layers.values():
        area = layer['area']
        ratio = (area / total_pixels) * 100 if total_pixels > 0 and area > 0 else 0
        has_mask = len(layer['masks']) > 0
        print(f"[calculate_total_area_ratio] Layer: {layer['name']}, Area: {area}, Ratio: {ratio}%, Masks: {len(layer['masks'])}, Has mask: {has_mask}")
        ratios.append({
            'layer_name': layer['name'],
            'area_pixels': int(area),
            'ratio_percent': round(ratio, 2),
        })
    return ratios


def create_area_chart_data(ratios):
    """Convert the output of ``calculate_total_area_ratio`` into a display table."""
    if not ratios:
        return pd.DataFrame(columns=["Layer", "Area (pixels)", "Ratio(%)"])

    rows = [
        {
            "Layer": entry['layer_name'],
            "Area (pixels)": f"{entry['area_pixels']:,}",
            "Ratio(%)": f"{entry['ratio_percent']}%",
        }
        for entry in ratios
    ]
    return pd.DataFrame(rows)


# ============ UTILITY FUNCTIONS ============
def _mask_to_uint8(mask):
    """Return ``mask`` as a 2-D uint8 array holding 0 (outside) or 255 (inside).

    Leading dimensions beyond two are dropped by taking the first entry, and
    any positive value counts as "inside".
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)
    while mask.ndim > 2:
        mask = mask[0]
    return (mask > 0).astype(np.uint8) * 255


def compose_all_layers(base_image, layer_manager, opacity=0.5, border_width=2):
    """Overlay every layer's masks on ``base_image``.

    Args:
        base_image: PIL image or numpy array.
        layer_manager: ``LayerManager`` providing colors and masks.
        opacity: Fill opacity of the mask color, 0.0-1.0.
        border_width: Width in pixels of the white outline; 0 disables it.

    Returns:
        PIL.Image.Image: The composed RGB image.
    """
    if isinstance(base_image, np.ndarray):
        base_image = Image.fromarray(base_image)
    base_image = base_image.convert("RGBA")

    if not layer_manager.layers:
        return base_image.convert("RGB")

    # The alpha is a constant per call. The previous lookup multiplied the
    # 0/255 mask value by opacity * 255, which is far outside the 8-bit range,
    # so the configured opacity was not honoured.
    fill_alpha = max(0, min(255, int(round(float(opacity) * 255))))
    border_width = max(0, int(border_width))

    overlay = Image.new("RGBA", base_image.size, (0, 0, 0, 0))
    for layer in layer_manager.layers.values():
        layer_color = tuple(layer['color'])
        for mask in layer['masks']:
            mask_np = _mask_to_uint8(mask)
            mask_img = Image.fromarray(mask_np)

            color_layer = Image.new("RGBA", base_image.size, layer_color + (0,))
            color_layer.putalpha(mask_img.point(lambda v: fill_alpha if v > 0 else 0))

            if border_width > 0:
                # The outline is best-effort: a failure here must not prevent
                # the mask itself from being drawn.
                try:
                    kernel_size = border_width * 2 + 1
                    kernel = np.ones((kernel_size, kernel_size), np.uint8)
                    dilated = cv2.dilate(mask_np, kernel)
                    border_img = Image.fromarray(dilated - mask_np)
                    border_layer = Image.new("RGBA", base_image.size, (255, 255, 255, 255))
                    border_layer.putalpha(border_img.point(lambda v: 255 if v > 0 else 0))
                    # The border goes underneath the color fill.
                    overlay = Image.alpha_composite(overlay, border_layer)
                except Exception as e:
                    print(f"Border creation error: {e}")

            overlay = Image.alpha_composite(overlay, color_layer)

    return Image.alpha_composite(base_image, overlay).convert("RGB")


def draw_points_on_image(image, layer_manager):
    """Return a copy of ``image`` with every layer's prompt points drawn on it.

    Positive points are red circles with a white center, negative points are
    blue circles with a white X. Points of the current layer are drawn larger.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    draw_img = image.copy()
    draw = ImageDraw.Draw(draw_img)

    for layer_id, layer in layer_manager.layers.items():
        r = 15 if layer_id == layer_manager.current_layer_id else 10
        for (x, y), label in zip(layer['points'], layer['point_labels']):
            box = (x - r, y - r, x + r, y + r)
            if label == 1:
                draw.ellipse(box, fill="red", outline="white", width=3)
                draw.ellipse((x - 3, y - 3, x + 3, y + 3), fill="white")
            else:
                draw.ellipse(box, fill="blue", outline="white", width=3)
                d = 8  # half-length of the X strokes
                draw.line([(x - d, y - d), (x + d, y + d)], fill="white", width=3)
                draw.line([(x - d, y + d), (x + d, y - d)], fill="white", width=3)

    return draw_img


def _render_result(image, manager, opacity=0.5, border_width=2):
    """Compose masks and points on ``image`` and build the area table.

    Returns:
        tuple: ``(result_image, area_dataframe)``.
    """
    composed = compose_all_layers(image, manager, opacity, border_width)
    result_image = draw_points_on_image(composed, manager)
    total_pixels = image.size[0] * image.size[1]
    ratios = calculate_total_area_ratio(manager, total_pixels)
    return result_image, create_area_chart_data(ratios)


# ============ UI FUNCTIONS ============
def update_layer_selector_choices(manager):
    """Build the ``gr.update`` that refreshes the layer radio choices.

    An update for the existing component is returned instead of a new
    component instance.
    """
    choices = [layer['name'] for layer in manager.layers.values()]
    current = manager.get_current_layer()
    current_value = current['name'] if current is not None else None
    return gr.update(choices=choices, value=current_value, interactive=True)


def create_layer_status_html(current_manager):
    """Render the (display-only) layer status panel as an HTML string."""
    if not current_manager.layers:
        return EMPTY_LAYERS_HTML

    cards = []
    for layer_id, layer in current_manager.layers.items():
        r, g, b = layer['color']
        color_hex = f"#{r:02x}{g:02x}{b:02x}"

        if current_manager.current_layer_id == layer_id:
            style = (
                f"background: linear-gradient(135deg, {color_hex}, {color_hex}dd); "
                "color: white; "
                "border: 3px solid #4682B4; "
                "box-shadow: 0 4px 12px rgba(70, 130, 180, 0.4);"
            )
        else:
            style = (
                f"background: linear-gradient(135deg, {color_hex}aa, {color_hex}77); "
                "color: white; "
                f"border: 2px solid {color_hex}; "
                "opacity: 0.7;"
            )

        positive_points = sum(1 for label in layer['point_labels'] if label == 1)
        negative_points = sum(1 for label in layer['point_labels'] if label == 0)
        masks_count = len(layer['masks'])
        status_icon = "[OK]" if masks_count > 0 else "[ ]"
        # Layer names are user input and must not be injected as raw HTML.
        safe_name = html.escape(layer['name'])

        cards.append(
            f'<div style="{style}">'
            f"<div>{status_icon} {safe_name}</div>"
            f"<div>+{positive_points} -{negative_points} {masks_count}mask</div>"
            "</div>"
        )

    return "<div>" + "".join(cards) + "</div>"


def create_new_layer(name, current_manager):
    """Create a layer called ``name`` and make it the current layer.

    Returns:
        tuple: ``(manager, status_html, layer_selector_update, message)``.
    """
    if current_manager is None:
        current_manager = LayerManager()

    def reply(message):
        return (
            current_manager,
            create_layer_status_html(current_manager),
            update_layer_selector_choices(current_manager),
            message,
        )

    clean_name = (name or "").strip()
    if not clean_name:
        return reply("Please enter a layer name")

    if any(layer['name'] == clean_name for layer in current_manager.layers.values()):
        return reply(f"Layer name '{name}' already exists")

    layer_id = current_manager.create_layer(clean_name)
    current_manager.set_current_layer(layer_id)
    return reply(f"Layer '{name}' created")


def click_on_image(current_manager, image, point_mode, evt: gr.SelectData):
    """Handle an image click by adding an include/exclude point to the current layer.

    Returns:
        tuple: ``(image_with_points, manager, status_html, message)``.
    """
    if current_manager is None:
        current_manager = LayerManager()

    if image is None or current_manager.current_layer_id is None:
        return image, current_manager, create_layer_status_html(current_manager), "Please select image and layer"

    x, y = evt.index
    label = 1 if point_mode == "positive" else 0
    layer_name = current_manager.layers[current_manager.current_layer_id]['name']

    print(f"\n[click_on_image] ================")
    print(f"[click_on_image] Layer: {layer_name}")
    print(f"[click_on_image] Point mode: {point_mode}, Label: {label}, Position: ({x}, {y})")

    current_manager.add_point_to_layer(current_manager.current_layer_id, [x, y], label)

    # The points are drawn on the original image (no mask overlay).
    result_image = draw_points_on_image(image, current_manager)
    mode_text = "Include" if label == 1 else "Exclude"
    return (
        result_image,
        current_manager,
        create_layer_status_html(current_manager),
        f"{mode_text} point added to '{layer_name}' at ({x}, {y})",
    )


def _segment_layer(image, layer):
    """Run the SAM3 tracker on one layer's points and return its mask.

    Requires ``load_models()`` to have succeeded.
    """
    input_points = [[layer['points']]]
    input_labels = [[layer['point_labels']]]

    # Inputs are moved to whatever device the model lives on.
    # NOTE(review): the model may be loaded in float16 while the processor
    # output is not cast here — confirm no dtype mismatch occurs on GPU.
    model_device = next(TRK_MODEL.parameters()).device
    inputs = TRK_PROCESSOR(
        images=image,
        input_points=input_points,
        input_labels=input_labels,
        return_tensors="pt",
    ).to(model_device)

    with torch.no_grad():
        outputs = TRK_MODEL(**inputs, multimask_output=False)

    masks = TRK_PROCESSOR.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True
    )[0]
    return masks[0]


def segment_all_layers(current_manager, image):
    """Segment every layer that has prompt points, in creation order.

    Returns:
        tuple: ``(result_image, manager, status_html, message, area_table)``.
    """
    if current_manager is None:
        current_manager = LayerManager()

    def failure(message):
        return None, current_manager, create_layer_status_html(current_manager), message, pd.DataFrame()

    if image is None:
        return failure("Please upload an image")
    if not current_manager.layers:
        return failure("Please create layers first")

    try:
        print(f"\n[segment_all_layers] Starting segmentation for all layers...")

        pending = []
        skipped_count = 0
        for layer_id, layer in current_manager.layers.items():
            if layer['points']:
                pending.append((layer_id, layer))
            else:
                print(f"[segment_all_layers] Skipping '{layer['name']}' - no points")
                skipped_count += 1

        # Load the models once, and only when there is something to segment.
        # A failure is reported instead of being shown as a completed run.
        if pending and not load_models():
            print("[segment_all_layers] Failed to load models")
            return failure("Error: failed to load SAM3 models")

        segmented_count = 0
        for layer_id, layer in pending:
            layer_name = layer['name']
            print(f"\n[segment_all_layers] Processing layer: {layer_name}")
            print(f"[segment_all_layers] Points: {len(layer['points'])}, Labels: {layer['point_labels']}")

            current_manager.add_mask_to_layer(layer_id, _segment_layer(image, layer))
            segmented_count += 1
            print(f"[segment_all_layers] Completed '{layer_name}'")

        result_image, chart_data = _render_result(image, current_manager)

        status_msg = f"Segmentation completed! Processed: {segmented_count} layers, Skipped: {skipped_count} layers"
        print(f"\n[segment_all_layers] {status_msg}")
        return result_image, current_manager, create_layer_status_html(current_manager), status_msg, chart_data

    except Exception as e:
        import traceback
        print(f"[segment_all_layers] Error: {str(e)}")
        traceback.print_exc()
        return failure(f"Error: {str(e)}")


def clear_current_layer(current_manager, image):
    """Remove points and masks from the current layer and re-render.

    Returns:
        tuple: ``(result_image, manager, status_html, message, area_table)``.
    """
    if current_manager is None:
        current_manager = LayerManager()

    if not current_manager.current_layer_id:
        return None, current_manager, create_layer_status_html(current_manager), "Please select a layer", pd.DataFrame()

    current_manager.clear_current_layer()

    if image is not None:
        result_image, chart_data = _render_result(image, current_manager)
    else:
        result_image = None
        chart_data = create_area_chart_data(calculate_total_area_ratio(current_manager, 0))

    return result_image, current_manager, create_layer_status_html(current_manager), "Layer cleared", chart_data


def refresh_visualization(current_manager, image, opacity, border_width):
    """Re-render the overlay with new opacity / border settings.

    Returns:
        tuple: ``(result_image, message, area_table)``.
    """
    if current_manager is None:
        current_manager = LayerManager()

    if image is None:
        return None, "Please upload an image", pd.DataFrame()

    result_image, chart_data = _render_result(image, current_manager, opacity, border_width)
    return result_image, "Visualization updated", chart_data


def on_layer_select(selected_name, mgr):
    """Make the layer called ``selected_name`` the current layer."""
    if mgr is None:
        mgr = LayerManager()

    if selected_name:
        # Look the layer id up by its display name.
        layer_id = next(
            (lid for lid, layer in mgr.layers.items() if layer['name'] == selected_name),
            None,
        )
        if layer_id:
            mgr.set_current_layer(layer_id)
            return mgr, create_layer_status_html(mgr), f"Layer '{selected_name}' selected"

    return mgr, create_layer_status_html(mgr), "Please select a layer"


def set_include_mode():
    """Switch to positive (include) point mode."""
    return "positive", "Include Point (Red)"


def set_exclude_mode():
    """Switch to negative (exclude) point mode."""
    return "negative", "Exclude Point (Blue)"


def on_image_upload(img):
    """Reset all layer state when the input image changes.

    The uploaded image is also shown in the output panel.
    """
    new_manager = LayerManager()
    return (
        new_manager,
        img,
        pd.DataFrame(),
        EMPTY_LAYERS_HTML,
        update_layer_selector_choices(new_manager),
        "positive",
        "Include Point (Red)",
        "New image uploaded",
    )


# ============ GRADIO INTERFACE ============
# No custom JavaScript needed anymore
custom_js = ""

with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# **SAM3 Layer Segmentation Tool**", elem_id="main-title")
        gr.Markdown("**Layer-based object separation and area analysis tool** | 1. Create layers 2. Select point mode and click 3. Run segmentation (processes all layers)")

        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image(type="pil", label="Upload Image", interactive=True, height=400)

                # Layer creation
                with gr.Row():
                    layer_name_input = gr.Textbox(label="Layer Name", placeholder="e.g. bench, tree, person")
                    create_layer_btn = gr.Button("Create", variant="primary")

                # Layer status display
                gr.Markdown("### Layers Status")
                layer_buttons_html = gr.HTML(EMPTY_LAYERS_HTML)

                # Layer selection (radio buttons)
                layer_selector = gr.Radio(label="Select Layer to Add Points", choices=[], interactive=True, value=None)

                # Point mode selection
                gr.Markdown("### Point Mode")
                with gr.Row():
                    include_btn = gr.Button("Include Point", variant="primary", size="sm")
                    exclude_btn = gr.Button("Exclude Point", variant="secondary", size="sm")

                point_mode_text = gr.Textbox(label="Current Mode", value="Include Point (Red)", interactive=False)

                # Usage hints
                gr.Markdown(INSTRUCTIONS_MD)

                # Controls
                with gr.Row():
                    segment_btn = gr.Button("Run All Segmentation", variant="primary", size="lg")
                    clear_btn = gr.Button("Clear Current Layer", variant="secondary")

                # Status
                status_text = gr.Textbox(label="Status", interactive=False)
                st_layer_manager = gr.State(None)  # the LayerManager is created inside the callbacks
                point_mode_state = gr.State("positive")  # "positive" or "negative"

            with gr.Column(scale=2):
                img_output = gr.Image(type="pil", label="Segmentation Result", height=400, interactive=False)

                # Area table
                area_table = gr.Dataframe(
                    label="Area Ratio by Layer"
                )

                # Visualization settings removed temporarily (avoids API schema conflicts)

    # Event wiring
    create_layer_btn.click(
        create_new_layer,
        inputs=[layer_name_input, st_layer_manager],
        outputs=[st_layer_manager, layer_buttons_html, layer_selector, status_text]
    )

    layer_selector.change(
        on_layer_select,
        inputs=[layer_selector, st_layer_manager],
        outputs=[st_layer_manager, layer_buttons_html, status_text]
    )

    # Point mode switching
    include_btn.click(
        set_include_mode,
        outputs=[point_mode_state, point_mode_text]
    )
    exclude_btn.click(
        set_exclude_mode,
        outputs=[point_mode_state, point_mode_text]
    )

    # Image clicks are accepted on both the input and the output image.
    img_input.select(
        click_on_image,
        inputs=[st_layer_manager, img_input, point_mode_state],
        outputs=[img_output, st_layer_manager, layer_buttons_html, status_text]
    )
    img_output.select(
        click_on_image,
        inputs=[st_layer_manager, img_input, point_mode_state],
        outputs=[img_output, st_layer_manager, layer_buttons_html, status_text]
    )

    # Segment all layers
    segment_btn.click(
        segment_all_layers,
        inputs=[st_layer_manager, img_input],
        outputs=[img_output, st_layer_manager, layer_buttons_html, status_text, area_table]
    )

    clear_btn.click(
        clear_current_layer,
        inputs=[st_layer_manager, img_input],
        outputs=[img_output, st_layer_manager, layer_buttons_html, status_text, area_table]
    )

    # Live opacity / border slider updates - temporarily disabled
    # opacity_slider.change(
    #     refresh_visualization,
    #     inputs=[st_layer_manager, img_input, opacity_slider, border_slider],
    #     outputs=[img_output, status_text, area_table]
    # )
    # border_slider.change(
    #     refresh_visualization,
    #     inputs=[st_layer_manager, img_input, opacity_slider, border_slider],
    #     outputs=[img_output, status_text, area_table]
    # )

    # Reset state when a new image is uploaded
    img_input.change(
        on_image_upload,
        inputs=[img_input],
        outputs=[st_layer_manager, img_output, area_table, layer_buttons_html, layer_selector, point_mode_state, point_mode_text, status_text]
    )


# NOTE(review): launching the server from inside a `spaces.GPU`-decorated
# function is unusual — confirm this is intended for the target Space.
@spaces.GPU
def run_spaces():
    """Launch the app when running on Hugging Face Spaces."""
    _launch_compat(demo, ssr=False)


def run_local():
    """Launch the app locally with error display enabled."""
    _launch_compat(demo, show_error=True, ssr=False)


if __name__ == "__main__":
    # Detect the Hugging Face Spaces environment.
    if os.getenv("SPACE_ID"):
        run_spaces()
    else:
        run_local()