U-Net for Gastrointestinal Polyp Segmentation (Multi-Class)
Multi-class (3-class) segmentation model trained on the Kvasir-SEG dataset.
See also the binary segmentation variant.
Architecture
Same U-Net encoder-decoder as the binary variant, with output changed to 3 channels:
- Encoder: 4 stages (64 → 128 → 256 → 512) + 1024-channel bottleneck.
- Decoder: 4 upsampling stages with skip connections (512 → 256 → 128 → 64).
- Reduction head: Conv2d(64→32) + BatchNorm + ReLU before the final 1×1 conv.
- Output: 3-channel logits (multi-class segmentation).
- Input: 3-channel RGB, 256×256.
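The reduction head and output layer listed above can be sketched as follows. This is a minimal illustration, not the repository's code: the class name `SegmentationHead` is hypothetical, and the 3×3 kernel with `padding=1` is an assumption, since the description only gives the channel counts.

```python
import torch
from torch import nn


class SegmentationHead(nn.Module):
    """Hypothetical head: Conv2d(64->32) + BatchNorm + ReLU, then a 1x1 conv to class logits."""

    def __init__(self, in_channels=64, mid_channels=32, num_classes=3):
        super().__init__()
        self.reduce = nn.Sequential(
            # Kernel size 3 with padding 1 is assumed; it keeps the spatial size unchanged.
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        # The 1x1 conv maps 32 feature channels to one logit per class.
        self.classifier = nn.Conv2d(mid_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.classifier(self.reduce(x))


head = SegmentationHead().eval()
features = torch.randn(2, 64, 256, 256)  # stand-in for the last decoder stage output
with torch.no_grad():
    logits = head(features)
print(logits.shape)
```

The head preserves the 256×256 resolution and only changes the channel dimension, so the logits line up pixel for pixel with the input image.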
Loss Function
CombinedLoss = 0.5 × CrossEntropy + 0.5 × Multi-class Dice Loss
- CrossEntropy: replaces BCE for the multi-class case, providing per-pixel classification gradients across all classes.
- Multi-class Dice Loss: computes per-class Dice using softmax probabilities and one-hot targets, then averages across classes. Because every class contributes equally to the average regardless of how many pixels it covers, small classes are not drowned out by large ones.
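The combined loss described above could be implemented along these lines. This is a sketch under assumptions rather than the training code itself: the smoothing term `eps`, the choice to pool Dice statistics over the whole batch, and the constructor arguments are all illustrative.

```python
import torch
import torch.nn.functional as F
from torch import nn


class CombinedLoss(nn.Module):
    """0.5 * CrossEntropy + 0.5 * multi-class Dice loss (sketch; eps is an assumed smoothing term)."""

    def __init__(self, ce_weight=0.5, dice_weight=0.5, eps=1e-6):
        super().__init__()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
        self.eps = eps

    def forward(self, logits, target):
        # logits: (B, C, H, W) raw scores; target: (B, H, W) int64 class indices.
        ce = F.cross_entropy(logits, target)

        num_classes = logits.shape[1]
        probs = torch.softmax(logits, dim=1)
        one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).to(probs.dtype)

        # Per-class Dice, pooled over batch and spatial dimensions (an assumption).
        dims = (0, 2, 3)
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice_per_class = (2 * intersection + self.eps) / (cardinality + self.eps)

        # Average over classes so each class carries equal weight.
        dice_loss = 1 - dice_per_class.mean()
        return self.ce_weight * ce + self.dice_weight * dice_loss


torch.manual_seed(0)
criterion = CombinedLoss()
logits = torch.randn(2, 3, 8, 8, requires_grad=True)
target = torch.randint(0, 3, (2, 8, 8))
loss = criterion(logits, target)
loss.backward()  # gradients flow through both terms
```

The cross-entropy term supplies a per-pixel gradient, while the Dice term works on region overlap, which is why the two are often paired for segmentation.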
Training Setup
| Parameter | Value |
|---|---|
| Epochs | 20 |
| Optimizer | Adam |
| Learning rate | 1e-4 |
| LR scheduler | ReduceLROnPlateau (factor=0.5, patience=3, min_lr=1e-6) |
| Batch size | 8 |
| Image size | 256×256 |
| GPU | NVIDIA GeForce RTX 3080 Ti |
| Dataset splits | train=800, val=100, test=100 |
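The optimizer and scheduler rows of the table translate into PyTorch roughly as below. The 1×1 convolution is only a stand-in for the U-Net, and `mode="min"` assumes the scheduler monitors validation loss; the table does not say which metric is tracked.

```python
import torch
from torch import nn

model = nn.Conv2d(3, 3, kernel_size=1)  # stand-in for the U-Net

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",  # assumption: the monitored quantity is validation loss
    factor=0.5,
    patience=3,
    min_lr=1e-6,
)

# Simulate a validation loss that stops improving after the first epoch.
val_losses = [0.50, 0.50, 0.50, 0.50, 0.50, 0.50]
lrs = []
for val_loss in val_losses:
    scheduler.step(val_loss)
    lrs.append(optimizer.param_groups[0]["lr"])
print(lrs)
```

With `patience=3`, the learning rate is halved once the monitored value has failed to improve for more than three consecutive epochs, and it never drops below `min_lr`.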
Metrics
| Split | Dice Coefficient | Loss |
|---|---|---|
| Validation (best) | 0.9036 | n/a |
| Test | 0.9169 | 0.3322 |
Usage
import torch
from model import UNet
from safetensors.torch import load_file
model = UNet(in_channels=3, num_classes=3)
model.load_state_dict(load_file("model.safetensors"))
model.eval()
# input: (B, 3, 256, 256) float tensor in [0, 1]
# output: (B, 3, 256, 256) logits; apply argmax(dim=1) to get the class map
Alternatively, download the weights from the Hugging Face Hub:
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from model import UNet
weights = hf_hub_download(
repo_id="sebastiao-teixeira/week04-polyp-segmentation-unet-multiclass",
filename="model.safetensors",
)
model = UNet(in_channels=3, num_classes=3)
model.load_state_dict(load_file(weights))
model.eval()
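Once the model is loaded, turning logits into a class map is a single `argmax` over the channel dimension. The helper below is a hypothetical convenience function, not part of the repository; a 1×1 convolution stands in for the loaded `UNet` so the sketch runs without `model.py` or the weights.

```python
import torch
from torch import nn


def predict_class_map(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Return (B, H, W) class indices for a (B, 3, H, W) float batch scaled to [0, 1]."""
    model.eval()
    with torch.no_grad():  # no gradients needed at inference time
        logits = model(images)  # (B, num_classes, H, W)
    return logits.argmax(dim=1)


# Stand-in model; replace with the loaded UNet in practice.
dummy = nn.Conv2d(3, 3, kernel_size=1)
batch = torch.rand(1, 3, 256, 256)
class_map = predict_class_map(dummy, batch)
print(class_map.shape, class_map.dtype)
```

Softmax is not required before `argmax`, because it does not change which channel holds the largest value.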