Pokemon Team Classification with Vision Transformer

A fine-tuned Vision Transformer (ViT) model for classifying the 6 Pokemon of a specific competitive team: Arceus, Marshadow, Sandy Shocks, Slaking, Reshiram, and Magearna.

Model Details

  • Base Model: google/vit-base-patch16-224
  • Model Type: Vision Transformer for Image Classification
  • Classes: 6 Pokemon (Arceus, Marshadow, Sandy Shocks, Slaking, Reshiram, Magearna)
  • Input Size: 224x224 RGB images
  • Framework: PyTorch + Transformers

Training Details

Dataset

  • Arceus: 644 images
  • Marshadow: 101 images
  • Sandy Shocks: 75 images
  • Slaking: 152 images
  • Reshiram: 118 images
  • Magearna: 200 images

Training Strategy

  • Balanced Sampling: Each epoch draws exactly 75 samples per class (the size of the smallest class) so training does not overfit to the much larger Arceus set; see the sketch after this list
  • Data Augmentation: Random horizontal flip, rotation (±15°), color jitter, and resized crop
  • Transfer Learning: Froze early ViT layers, fine-tuned classifier and later transformer layers
  • Early Stopping: Training stopped when validation loss plateaued (patience=3 epochs)
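
A minimal sketch of how the augmentation pipeline and the per-epoch balanced sampling described above could be implemented, assuming a torchvision-style dataset that exposes a list of per-sample labels. The crop scale and jitter strengths are illustrative choices, and the dataset wiring in the commented lines is hypothetical:

import random
from collections import defaultdict

from torchvision import transforms

# Augmentations matching the list above (exact strengths are assumptions)
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
])

def balanced_epoch_indices(labels, per_class=75):
    # Group sample indices by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Draw exactly per_class indices from each class
    # (with replacement only if a class is smaller than per_class)
    indices = []
    for class_indices in by_class.values():
        if len(class_indices) >= per_class:
            indices.extend(random.sample(class_indices, per_class))
        else:
            indices.extend(random.choices(class_indices, k=per_class))
    random.shuffle(indices)
    return indices

# Rebuild the subset each epoch so Arceus (644 images) cannot dominate:
# from torch.utils.data import DataLoader, Subset
# epoch_indices = balanced_epoch_indices(full_dataset.labels)   # hypothetical dataset
# loader = DataLoader(Subset(full_dataset, epoch_indices), batch_size=16, shuffle=True)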

Hyperparameters

  • Learning Rate: 2e-5
  • Batch Size: 16
  • Weight Decay: 0.01
  • Optimizer: AdamW
  • Epochs: ~18 (early stopped from max 1000)
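
One way the optimizer, partial freezing, and early stopping could be wired together under these settings. The freeze cut-off (first 6 of 12 encoder layers) and the helpers train_one_epoch/evaluate are assumptions for illustration, not the author's exact code:

import torch
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=6,
    ignore_mismatched_sizes=True,  # swap the 1000-class ImageNet head for a fresh 6-class one
)

# Freeze the patch embeddings and the early encoder layers;
# the exact cut-off is an assumption, the card only says "early layers"
for param in model.vit.embeddings.parameters():
    param.requires_grad = False
for block in model.vit.encoder.layer[:6]:
    for param in block.parameters():
        param.requires_grad = False

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=2e-5,
    weight_decay=0.01,
)

# Early stopping with patience=3 (train_one_epoch/evaluate are hypothetical helpers)
best_val_loss, patience, bad_epochs = float("inf"), 3, 0
for epoch in range(1000):  # max epochs; early stopping ends training far sooner
    train_one_epoch(model, loader, optimizer)
    val_loss = evaluate(model, val_loader)
    if val_loss < best_val_loss:
        best_val_loss, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break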

Performance

Thanks to the balanced sampling strategy, the model classifies all 6 Pokemon with consistently high accuracy despite the heavily imbalanced training dataset (644 Arceus images vs. 75 for Sandy Shocks).

Usage

Basic Classification

from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch

# Load model and processor
model = ViTForImageClassification.from_pretrained("N-o-1/pokemon-classifier")
processor = ViTImageProcessor.from_pretrained("N-o-1/pokemon-classifier")

# Load and process image
image = Image.open("pokemon_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Get predictions
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

# Get results (the label order must match the class order used at training time)
pokemon_names = ["arceus", "marshadow", "sandy-shocks", "slaking", "reshiram", "magearna"]
predicted_class = predictions.argmax().item()
confidence = predictions.max().item()

print(f"Predicted: {pokemon_names[predicted_class]} (confidence: {confidence:.2%})")

Detailed Probabilities

# Get all class probabilities (reuses `outputs` from the snippet above)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

# Map each class name to its probability
results = {pokemon: float(prob) for pokemon, prob in zip(pokemon_names, probabilities)}

# Sort by probability
sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
for pokemon, prob in sorted_results:
    print(f"{pokemon}: {prob:.1%}")

Applications

  • Pokemon Recognition: Identify specific Pokemon in images, artwork, or screenshots
  • Competitive Team Analysis: Analyze team compositions in competitive Pokemon content
  • Content Moderation: Filter or categorize Pokemon-related content
  • Educational Tools: Pokemon identification for learning applications

Limitations

  • Specific Pokemon Only: Only recognizes the 6 trained Pokemon classes
  • Image Quality: Performance may vary with very low resolution or heavily distorted images
  • Artistic Variations: May struggle with highly stylized or non-canonical Pokemon representations
  • Background Complexity: Performance may decrease with very cluttered backgrounds

Model Architecture

The model uses the Vision Transformer (ViT) architecture:

  • Patch Size: 16x16
  • Hidden Size: 768
  • Attention Heads: 12
  • Layers: 12
  • Parameters: ~85.8M total (ViT-Base backbone + 6-class classification head)
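
These figures can be checked directly against the base model's configuration; a quick sanity-check snippet, not part of the training code:

from transformers import ViTConfig, ViTForImageClassification

config = ViTConfig.from_pretrained("google/vit-base-patch16-224")
print(config.patch_size, config.hidden_size,
      config.num_attention_heads, config.num_hidden_layers)
# -> 16 768 12 12

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224", num_labels=6, ignore_mismatched_sizes=True
)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
# -> roughly 85.8M with the 6-class head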

Training Infrastructure

  • Hardware: AMD GPU with ROCm support
  • Framework: PyTorch with Transformers library
  • Duration: ~2 minutes per epoch, early stopped at epoch 18
  • Memory: Optimized for consumer-grade GPU memory

Citation

If you use this model, please cite:

@misc{pokemon-team-vit,
  title={Pokemon Team Classification with Vision Transformer},
  author={Steven Van Ingelgem},
  year={2025},
  url={https://huggingface.co/N-o-1/pokemon-classifier}
}

License

This model is released under the MIT License. The training data consists of Pokemon images which are © The Pokémon Company/Nintendo. This model is for research and educational purposes.

Acknowledgments

  • Base model: Google's Vision Transformer (ViT)
  • Training framework: Hugging Face Transformers
  • Pokemon images: collected from various sources for competitive team analysis

Note: This model is specifically trained for a competitive Pokemon team setup and may not generalize to other Pokemon or use cases. For broader Pokemon classification, consider training on a more comprehensive dataset.
