import json
from collections import Counter
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# NOTE: torchvision and HF `datasets` are imported lazily inside the functions
# that need them, so the pure helpers below (class weights, accuracy, epoch
# loop) can be imported and unit-tested without the heavy vision/data stack.

# --- Fixed inputs for reproducibility and consistent artifact paths ---
SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"
ART_DIR = Path("artifacts")
ART_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 224
BATCH_SIZE = 16
NUM_WORKERS = 0  # Windows-safe default (avoid multiprocessing issues)
EPOCHS = 8
LR = 3e-4
SEED = 42

# FIX: the original hard-coded USE_PRETRAINED = False, contradicting the
# script's own "Transfer Learning ... start from pretrained ImageNet weights"
# design note and silently training ResNet-18 from scratch. Pretrained
# weights are the point of fine-tuning on a small damage dataset.
USE_PRETRAINED = True

# ImageNet normalization stats — must match the pretrained backbone's training.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def set_seed(seed: int) -> None:
    """Seed torch's CPU and all-CUDA RNGs for consistent shuffling/initialization."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def compute_class_weights(labels, num_classes: int) -> torch.Tensor:
    """Inverse-frequency class weights to reduce bias toward majority classes.

    Args:
        labels: iterable of integer class ids (one per training example).
        num_classes: total number of classes; ids assumed in [0, num_classes).

    Returns:
        Float tensor of shape (num_classes,), normalized so the mean weight
        is ~1 (keeps the loss scale stable).

    Raises:
        ValueError: if `labels` is empty (the original code raised an opaque
        ZeroDivisionError here).
    """
    labels = list(labels)
    if not labels:
        raise ValueError("labels must be non-empty")
    counts = Counter(labels)
    total = len(labels)
    weights = []
    for k in range(num_classes):
        # A class absent from `labels` counts as 1 to avoid division by zero.
        freq = counts.get(k, 1) / total
        weights.append(1.0 / freq)
    w = torch.tensor(weights, dtype=torch.float)
    return w / w.mean()


def accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Top-1 accuracy: fraction of rows where argmax(logits) equals the label."""
    preds = logits.argmax(dim=1)
    return (preds == labels).float().mean().item()


def build_transforms():
    """Return (train_tf, val_tf) torchvision pipelines.

    Train adds a horizontal flip for mild augmentation to improve
    generalization; val is deterministic only.
    """
    from torchvision import transforms  # lazy: heavy optional dependency

    normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    train_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize,
    ])
    return train_tf, val_tf


def make_collate(tf):
    """Build a DataLoader collate_fn applying torchvision transform `tf`.

    HF datasets store PIL Images, so transforms are applied per batch here
    rather than in the dataset itself.
    """
    def collate(batch):
        pixel_values = torch.stack(
            [tf(row["image"].convert("RGB")) for row in batch]
        )
        labels = torch.tensor([row["label"] for row in batch], dtype=torch.long)
        return {"pixel_values": pixel_values, "labels": labels}

    return collate


def build_model(num_classes: int) -> nn.Module:
    """ResNet-18 (pretrained iff USE_PRETRAINED) with a fresh classifier head."""
    from torchvision import models  # lazy: heavy optional dependency

    weights = models.ResNet18_Weights.DEFAULT if USE_PRETRAINED else None
    model = models.resnet18(weights=weights)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over `loader`.

    Trains (backprop + optimizer step per batch) when `optimizer` is given;
    otherwise evaluates under torch.no_grad().

    Returns:
        (mean_loss, mean_top1_acc), each weighted by per-batch sample count.
    """
    training = optimizer is not None
    model.train(mode=training)
    total_loss = 0.0
    total_acc = 0.0
    n_seen = 0
    grad_ctx = torch.enable_grad() if training else torch.no_grad()
    with grad_ctx:
        for batch in loader:
            x = batch["pixel_values"].to(device)
            y = batch["labels"].to(device)
            logits = model(x)
            loss = criterion(logits, y)
            if training:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
            bs = y.size(0)
            # Weight running stats by batch size so the last (short) batch
            # doesn't skew the epoch averages.
            total_loss += loss.item() * bs
            total_acc += accuracy(logits.detach(), y) * bs
            n_seen += bs
    return total_loss / n_seen, total_acc / n_seen


def main():
    """Fine-tune ResNet-18 on the saved car-damage splits; checkpoint the best."""
    from datasets import load_from_disk  # lazy: heavy optional dependency

    set_seed(SEED)

    # Device selection (CPU is fine; CUDA if available).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # Load the *saved* stratified splits (do NOT re-split each run).
    splits = load_from_disk(SPLIT_DIR)
    train_ds = splits["train"]
    val_ds = splits["val"]

    # Label metadata (source of truth for class order).
    label_names = train_ds.features["label"].names
    num_classes = len(label_names)
    print("Classes:", label_names)

    # Save label map alongside the model artifact (needed for inference/API).
    with open(ART_DIR / "label_names.json", "w", encoding="utf-8") as f:
        json.dump(label_names, f, ensure_ascii=False, indent=2)

    train_tf, val_tf = build_transforms()
    # DataLoaders: train shuffled, val not shuffled.
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS,
                              collate_fn=make_collate(train_tf))
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS,
                            collate_fn=make_collate(val_tf))

    model = build_model(num_classes).to(device)

    # Loss with class weights (handles mild imbalance).
    class_w = compute_class_weights(train_ds["label"], num_classes).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_w)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    # Checkpointing: keep the best model by validation accuracy.
    best_val_acc = -1.0
    best_path = ART_DIR / "model.pt"

    for epoch in range(1, EPOCHS + 1):
        train_loss, train_acc = run_epoch(model, train_loader, criterion,
                                          device, optimizer)
        val_loss, val_acc = run_epoch(model, val_loader, criterion, device)

        print(f"Epoch {epoch:02d}/{EPOCHS} | "
              f"train loss {train_loss:.4f} acc {train_acc:.4f} | "
              f"val loss {val_loss:.4f} acc {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "label_names": label_names,
                "img_size": IMG_SIZE,
                "arch": "resnet18",
            }, best_path)
            print(f" -> saved best to {best_path} (val_acc={best_val_acc:.4f})")

    print("\nTraining complete.")
    print("Best val acc:", best_val_acc)


if __name__ == "__main__":
    main()