U-Net for Gastrointestinal Polyp Segmentation (Multi-Class)

Multi-class (3-class) segmentation model trained on the Kvasir-SEG dataset.

See also the binary segmentation variant.

Architecture

Same U-Net encoder-decoder as the binary variant, with output changed to 3 channels:

  • Encoder: 4 stages (64 → 128 → 256 → 512) + 1024-channel bottleneck.
  • Decoder: 4 upsampling stages with skip connections (512 → 256 → 128 → 64).
  • Reduction head: Conv2d(64→32) + BatchNorm + ReLU before the final 1×1 conv.
  • Output: 3-channel logits (multi-class segmentation).
  • Input: 3-channel RGB, 256×256.
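The stage sizes are easier to check when traced for a 256×256 input. The sketch below is a plain-Python shape calculation, not the model itself. It assumes details the card does not state: 2×2 max-pooling after each encoder stage, padded 3×3 convolutions that keep the spatial size within a stage, and 2× upsampling in each decoder stage.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int]  # (channels, height, width)


def trace_unet_shapes(size: int = 256, in_channels: int = 3,
                      num_classes: int = 3) -> List[Tuple[str, Shape]]:
    """Trace (C, H, W) shapes through the U-Net described above.

    Assumes 2x2 max-pooling after each encoder stage, padded 3x3 convs
    (spatial size unchanged within a stage) and 2x upsampling per decoder stage.
    """
    encoder_channels = [64, 128, 256, 512]
    if size % 2 ** len(encoder_channels) != 0:
        raise ValueError("input size must be divisible by 16 for four 2x poolings")

    shapes: List[Tuple[str, Shape]] = [("input", (in_channels, size, size))]
    skips: List[Shape] = []
    h = size
    for i, ch in enumerate(encoder_channels, start=1):
        shapes.append((f"enc{i}", (ch, h, h)))
        skips.append((ch, h, h))  # saved for the matching decoder stage
        h //= 2  # assumed 2x2 max-pool

    shapes.append(("bottleneck", (1024, h, h)))

    for i, skip in enumerate(reversed(skips), start=1):
        h *= 2  # assumed 2x upsampling
        ch = skip[0]
        # Upsampled features (ch) are concatenated with the skip (ch) -> 2*ch,
        # then convolved back to ch channels (assumed standard U-Net wiring).
        assert (h, h) == skip[1:], "skip connection must match spatially"
        shapes.append((f"dec{i}", (ch, h, h)))

    shapes.append(("reduction_head", (32, h, h)))   # Conv2d(64->32) + BN + ReLU
    shapes.append(("logits", (num_classes, h, h)))  # final 1x1 conv
    return shapes


for name, shape in trace_unet_shapes():
    print(f"{name:>15}: {shape}")
```

Under these assumptions the bottleneck runs at 16×16, which is why the input side must be a multiple of 16.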

Loss Function

CombinedLoss = 0.5 × CrossEntropy + 0.5 × Multi-class Dice Loss

  • CrossEntropy: replaces the BCE term used in the binary variant, providing per-pixel classification gradients across all three classes.
  • Multi-class Dice Loss: computes a Dice score per class from softmax probabilities and one-hot targets, then averages across classes. Because every class carries equal weight in that average regardless of how many pixels it covers, this term keeps small classes from being swamped by the dominant one.
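A minimal NumPy sketch of this combined loss follows. The card does not give the Dice smoothing constant or the reduction, so the sketch assumes a smoothing value of 1.0 and per-class Dice accumulated over the whole batch before averaging across classes. Targets are assumed to be integer masks with labels 0 to 2. A PyTorch training loop would use `torch.nn.functional.cross_entropy` for the first term and a tensor version of the Dice term.

```python
import numpy as np


def log_softmax(logits: np.ndarray, axis: int = 1) -> np.ndarray:
    """Numerically stable log-softmax over the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def cross_entropy(logits: np.ndarray, target: np.ndarray) -> float:
    """Mean per-pixel CE. logits: (B, C, H, W); target: (B, H, W) int labels."""
    log_probs = log_softmax(logits, axis=1)
    picked = np.take_along_axis(log_probs, target[:, None, :, :], axis=1)
    return float(-picked.mean())


def multiclass_dice_loss(logits: np.ndarray, target: np.ndarray,
                         num_classes: int, smooth: float = 1.0) -> float:
    """1 - mean over classes of soft Dice (softmax probs vs one-hot targets)."""
    probs = np.exp(log_softmax(logits, axis=1))
    one_hot = np.eye(num_classes)[target].transpose(0, 3, 1, 2)  # (B, C, H, W)
    axes = (0, 2, 3)  # sum over batch and pixels, keep the class axis
    intersection = (probs * one_hot).sum(axis=axes)
    cardinality = probs.sum(axis=axes) + one_hot.sum(axis=axes)
    dice_per_class = (2.0 * intersection + smooth) / (cardinality + smooth)
    return float(1.0 - dice_per_class.mean())


def combined_loss(logits: np.ndarray, target: np.ndarray, num_classes: int = 3,
                  ce_weight: float = 0.5, dice_weight: float = 0.5) -> float:
    """0.5 * CrossEntropy + 0.5 * multi-class Dice loss, as in the formula above."""
    return (ce_weight * cross_entropy(logits, target)
            + dice_weight * multiclass_dice_loss(logits, target, num_classes))
```

With all-zero logits the softmax is uniform and the cross-entropy term equals ln 3; with confident, correct logits both terms approach zero.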

Training Setup

| Parameter | Value |
|---|---|
| Epochs | 20 |
| Optimizer | Adam |
| Learning rate | 1e-4 |
| LR scheduler | ReduceLROnPlateau (factor=0.5, patience=3, min_lr=1e-6) |
| Batch size | 8 |
| Image size | 256×256 |
| GPU | NVIDIA GeForce RTX 3080 Ti |
| Dataset splits | train=800, val=100, test=100 |
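The scheduler settings can be read as a simple rule: after more than `patience` epochs without improvement, multiply the learning rate by `factor`, never going below `min_lr`. The sketch below re-implements that rule in plain Python for illustration. It is not PyTorch's implementation, and because the card does not say which metric is monitored, it assumes validation loss in `mode="min"` with PyTorch's default relative threshold of 1e-4 and no cooldown.

```python
class PlateauScheduler:
    """Simplified ReduceLROnPlateau: mode='min', relative threshold, no cooldown.

    Uses the card's settings (factor=0.5, patience=3, min_lr=1e-6); an
    illustration of the rule, not PyTorch's implementation.
    """

    def __init__(self, lr: float = 1e-4, factor: float = 0.5, patience: int = 3,
                 min_lr: float = 1e-6, threshold: float = 1e-4) -> None:
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.threshold = threshold
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, val_metric: float) -> float:
        # An epoch counts as an improvement only if it beats the best value
        # by the relative threshold (PyTorch's default threshold_mode='rel').
        if val_metric < self.best * (1.0 - self.threshold):
            self.best = val_metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # More than `patience` stagnant epochs in a row: scale the LR, floored at min_lr.
        if self.num_bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauScheduler()
for epoch, val_loss in enumerate([1.0, 0.9, 0.9, 0.9, 0.9, 0.9], start=1):
    print(epoch, sched.step(val_loss))
```

Here the first halving happens at epoch 6, the fourth consecutive epoch without improvement. In PyTorch the corresponding call is `torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-6)`, stepped once per epoch with the monitored value. With 800 training images and batch size 8, each epoch is 100 optimizer steps, or 2,000 steps over 20 epochs.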

Metrics

| Split | Dice coefficient | Loss |
|---|---|---|
| Validation (best) | 0.9036 | n/a |
| Test | 0.9169 | 0.3322 |
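The card does not say exactly how the reported Dice coefficient is computed. One common convention, sketched here as an assumption, takes the argmax class map, computes Dice per class on hard labels, and averages across all classes including background. A class absent from both prediction and target is counted as a perfect 1.0.

```python
import numpy as np


def per_class_dice(pred: np.ndarray, target: np.ndarray, num_classes: int = 3) -> np.ndarray:
    """Dice per class on hard label maps of identical shape (e.g. argmax output)."""
    if pred.shape != target.shape:
        raise ValueError("pred and target must have the same shape")
    scores = np.empty(num_classes, dtype=np.float64)
    for c in range(num_classes):
        p = pred == c
        t = target == c
        denom = p.sum() + t.sum()
        # Assumed convention: a class absent from both maps scores a perfect 1.0.
        scores[c] = 1.0 if denom == 0 else 2.0 * np.logical_and(p, t).sum() / denom
    return scores


def mean_dice(pred: np.ndarray, target: np.ndarray, num_classes: int = 3) -> float:
    """Unweighted mean of per-class Dice scores."""
    return float(per_class_dice(pred, target, num_classes).mean())
```

Whether background is averaged in, and how empty classes are scored, can shift the number noticeably, so these choices matter when comparing against the table.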

Usage

```python
import torch
from model import UNet
from safetensors.torch import load_file

model = UNet(in_channels=3, num_classes=3)
model.load_state_dict(load_file("model.safetensors"))
model.eval()

# input:  (B, 3, 256, 256) float tensor in [0, 1]
# output: (B, 3, 256, 256) logits; apply argmax(dim=1) to get the class map
```
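The comments above fix the tensor layout but not the conversion from an image file. The NumPy sketch below assumes an RGB `uint8` image already resized to 256×256, and that scaling to [0, 1] is the only normalization; the card mentions no mean/std standardization. The helper names `preprocess` and `logits_to_class_map` are hypothetical.

```python
import numpy as np


def preprocess(image: np.ndarray) -> np.ndarray:
    """RGB uint8 image (256, 256, 3) -> float32 batch (1, 3, 256, 256) in [0, 1].

    Assumes the image is already 256x256 and that /255 scaling is the only
    normalization the model expects (not confirmed by the card).
    """
    if image.shape != (256, 256, 3) or image.dtype != np.uint8:
        raise ValueError("expected a (256, 256, 3) uint8 RGB image")
    x = image.astype(np.float32) / 255.0
    x = x.transpose(2, 0, 1)[None]  # HWC -> CHW, then add a batch dimension
    return np.ascontiguousarray(x)


def logits_to_class_map(logits: np.ndarray) -> np.ndarray:
    """(B, 3, H, W) logits -> (B, H, W) integer class map via argmax over classes."""
    return logits.argmax(axis=1)


# Hypothetical usage with the model loaded above:
#   x = torch.from_numpy(preprocess(image))
#   with torch.no_grad():
#       class_map = logits_to_class_map(model(x).numpy())
```

Argmax over the raw logits gives the same class map as argmax over softmax probabilities, so no softmax is needed at inference time.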

Or download the weights from the Hugging Face Hub:

```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from model import UNet

weights = hf_hub_download(
    repo_id="sebastiao-teixeira/week04-polyp-segmentation-unet-multiclass",
    filename="model.safetensors",
)
model = UNet(in_channels=3, num_classes=3)
model.load_state_dict(load_file(weights))
model.eval()
```