---
license: mit
base_model:
- google/vit-base-patch16-224
pipeline_tag: image-classification
datasets:
- N-o-1/pokemon-images
---

# Pokemon Team Classification with Vision Transformer

A fine-tuned Vision Transformer (ViT) model for classifying 6 specific Pokemon from a competitive team setup. This model can identify Arceus, Marshadow, Sandy Shocks, Slaking, Reshiram, and Magearna with high accuracy.

## Model Details

- **Base Model**: `google/vit-base-patch16-224`
- **Model Type**: Vision Transformer for Image Classification
- **Classes**: 6 Pokemon (Arceus, Marshadow, Sandy Shocks, Slaking, Reshiram, Magearna)
- **Input Size**: 224x224 RGB images
- **Framework**: PyTorch + Transformers

## Training Details

### Dataset

- **Arceus**: 644 images
- **Marshadow**: 101 images
- **Sandy Shocks**: 75 images
- **Slaking**: 152 images
- **Reshiram**: 118 images
- **Magearna**: 200 images

### Training Strategy

- **Balanced Sampling**: Each epoch draws exactly 75 samples per class so the heavily over-represented Arceus class cannot dominate training (see the sketch below)
- **Data Augmentation**: Random horizontal flip, rotation (±15°), color jitter, and resized crop
- **Transfer Learning**: Froze early ViT layers; fine-tuned the classifier head and later transformer layers
- **Early Stopping**: Training stopped when validation loss plateaued (patience=3 epochs)

### Hyperparameters

- **Learning Rate**: 2e-5
- **Batch Size**: 16
- **Weight Decay**: 0.01
- **Optimizer**: AdamW
- **Epochs**: ~18 (early stopped from a maximum of 1000)
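### Training Loop Sketch

The exact training script is not included with this card; the snippet below is a minimal sketch of the balanced-sampling and layer-freezing strategy described above, assuming a plain PyTorch setup. `train_labels` (a list of integer class labels aligned with the training set) and the cutoff for "early layers" are illustrative assumptions.

```python
import random

import torch
from transformers import ViTForImageClassification

# Load the backbone with a fresh 6-class head in place of the ImageNet-1k head
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=6,
    ignore_mismatched_sizes=True,
)

# Freeze the patch embeddings and the first half of the encoder;
# the exact "early layers" cutoff is an assumption
for param in model.vit.embeddings.parameters():
    param.requires_grad = False
for layer in model.vit.encoder.layer[:6]:
    for param in layer.parameters():
        param.requires_grad = False

# Optimize only the unfrozen parameters, using the hyperparameters above
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=2e-5,
    weight_decay=0.01,
)

def balanced_indices(labels, per_class=75):
    """Pick exactly `per_class` sample indices per class for one epoch."""
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)
    chosen = []
    for indices in by_class.values():
        chosen.extend(random.sample(indices, min(per_class, len(indices))))
    random.shuffle(chosen)
    return chosen

# Each epoch then trains on
# DataLoader(Subset(train_dataset, balanced_indices(train_labels)), batch_size=16)
# until early stopping triggers.
```

Re-sampling a fresh balanced subset every epoch means every Arceus image is eventually seen, while each individual epoch stays class-balanced.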
## Performance

The model achieves strong classification performance, with roughly balanced accuracy across all 6 Pokemon classes despite the imbalanced training dataset.

## Usage

### Basic Classification

```python
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch

# Load model and processor ("your-username/pokemon-team-vit" is a placeholder)
model = ViTForImageClassification.from_pretrained("your-username/pokemon-team-vit")
processor = ViTImageProcessor.from_pretrained("your-username/pokemon-team-vit")

# Load and preprocess the image (convert to RGB in case of grayscale/RGBA input)
image = Image.open("pokemon_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Get predictions
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

# Map the predicted index to a name; this order must match the checkpoint's
# label mapping (model.config.id2label)
pokemon_names = ["arceus", "marshadow", "sandy-shocks", "slaking", "reshiram", "magearna"]
predicted_class = predictions.argmax().item()
confidence = predictions.max().item()

print(f"Predicted: {pokemon_names[predicted_class]} (confidence: {confidence:.2%})")
```

### Detailed Probabilities

```python
# Reuses `outputs` and `pokemon_names` from the snippet above
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

results = {}
for idx, pokemon in enumerate(pokemon_names):
    results[pokemon] = float(probabilities[idx])

# Sort by probability, highest first
sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
for pokemon, prob in sorted_results:
    print(f"{pokemon}: {prob:.1%}")
```

## Applications

- **Pokemon Recognition**: Identify specific Pokemon in images, artwork, or screenshots
- **Competitive Team Analysis**: Analyze team compositions in competitive Pokemon content
- **Content Moderation**: Filter or categorize Pokemon-related content
- **Educational Tools**: Pokemon identification for learning applications

## Limitations

- **Specific Pokemon Only**: Only recognizes the 6 trained Pokemon classes
- **Image Quality**: Performance may vary with very low resolution or heavily distorted images
- **Artistic Variations**: May struggle with highly stylized or non-canonical Pokemon representations
- **Background Complexity**: Performance may decrease with very cluttered backgrounds

## Model Architecture

The model uses the Vision Transformer (ViT) architecture:

- **Patch Size**: 16x16
- **Hidden Size**: 768
- **Attention Heads**: 12
- **Layers**: 12
- **Parameters**: ~86M (base model) + classification head

## Training Infrastructure

- **Hardware**: AMD GPU with ROCm support
- **Framework**: PyTorch with the Transformers library
- **Duration**: ~2 minutes per epoch; early stopped at epoch 18
- **Memory**: Optimized for consumer-grade GPU memory

## Citation

If you use this model, please cite:

```bibtex
@misc{pokemon-team-vit,
  title={Pokemon Team Classification with Vision Transformer},
  author={Steven Van Ingelgem},
  year={2025},
  url={https://huggingface.co/your-username/pokemon-team-vit}
}
```

## License

This model is released under the MIT License. The training data consists of Pokemon images, which are © The Pokémon Company/Nintendo. This model is for research and educational purposes.

## Acknowledgments

- Base model: Google's Vision Transformer (ViT)
- Training framework: Hugging Face Transformers
- Pokemon images: Various sources for competitive team analysis

---

**Note**: This model is specifically trained for a competitive Pokemon team setup and may not generalize to other Pokemon or use cases. For broader Pokemon classification, consider training on a more comprehensive dataset.
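The architecture figures in the Model Architecture section, and the label order assumed in the Usage snippets, can both be read back from the released checkpoint. A minimal sketch, using the same placeholder repo id as in the Usage section:

```python
from transformers import ViTForImageClassification

# "your-username/pokemon-team-vit" is the placeholder repo id from Usage
model = ViTForImageClassification.from_pretrained("your-username/pokemon-team-vit")
cfg = model.config

# Should print 16, 768, 12, 12 for a ViT-base backbone
print(cfg.patch_size, cfg.hidden_size, cfg.num_attention_heads, cfg.num_hidden_layers)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")
print(cfg.id2label)  # the label order baked into the checkpoint
```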