import gradio as gr
import torch
import clip
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from torch.utils.data import Dataset
import torch.nn as nn
import urllib.parse
import re
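# Note: `clip` here is OpenAI's CLIP package, typically installed with
# `pip install git+https://github.com/openai/CLIP.git` (not the unrelated
# `clip` package on PyPI).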
# Set device
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("Using CPU device")
# Dataset class
class ArtDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.annotations = pd.read_csv(csv_file, delimiter=";")
        self.transform = transform
        # Map each unique style/artist name to an integer class index
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_path = self.annotations.iloc[idx]['filename']
        # Percent-encode the stored path, leaving '/' and ':' intact
        safe_img_path = urllib.parse.quote(img_path, safe="/:")
        try:
            image = Image.open(safe_img_path).convert("RGB")
            style_label = self.label_map_style[self.annotations.iloc[idx]['genre']]
            artist_label = self.label_map_artist[self.annotations.iloc[idx]['artist']]
            if self.transform:
                image = self.transform(image)
            return image, (style_label, artist_label)
        except (FileNotFoundError, OSError):
            # Unreadable or missing image: return a sentinel the caller must handle
            return None, (None, None)
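# Because __getitem__ returns (None, (None, None)) for broken images, a plain
# DataLoader would fail when batching. A minimal collate sketch that skips such
# samples (an assumption about intended usage, not part of the original code;
# default_collate requires torch >= 1.11):
from torch.utils.data import default_collate

def skip_broken_collate(batch):
    # Drop samples whose image failed to load
    batch = [item for item in batch if item[0] is not None]
    return default_collate(batch) if batch else None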
# Image transformations
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Load dataset
csv_file = "classes.csv"
dataset = ArtDataset(csv_file=csv_file, transform=data_transforms)
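# classes.csv is expected to be semicolon-delimited with at least the columns
# 'filename', 'artist', and 'genre'; the row below is purely illustrative:
#   filename;artist;genre
#   images/monet_water_lilies.jpg;Claude Monet;Impressionism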
# Define model
class DualOutputResNet(nn.Module):
    def __init__(self, num_styles, num_artists):
        super(DualOutputResNet, self).__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_features = self.backbone.fc.in_features
        # Replace the ImageNet classification head so the backbone emits raw features
        self.backbone.fc = nn.Identity()
        # Two independent heads: one for style, one for artist
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        features = self.backbone(x)
        style_output = self.fc_style(features)
        artist_output = self.fc_artist(features)
        return style_output, artist_output
# Load pre-trained model
num_styles = len(dataset.label_map_style)
num_artists = len(dataset.label_map_artist)
model = DualOutputResNet(num_styles, num_artists).to(device)
model.load_state_dict(torch.load("dual_output_resnet.pth", map_location=device))
model.eval()

# Load CLIP model
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.eval()
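# model_clip is loaded but not used in the pipeline below. A minimal sketch of
# how it could score an image against candidate style names (hypothetical
# helper, not part of the original flow):
def clip_style_scores(image, candidate_styles):
    image_input = preprocess_clip(image).unsqueeze(0).to(device)
    text_input = clip.tokenize(candidate_styles, truncate=True).to(device)
    with torch.no_grad():
        image_features = model_clip.encode_image(image_input)
        text_features = model_clip.encode_text(text_input)
        # Normalize so the dot product is a cosine similarity
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return (image_features @ text_features.T).squeeze(0)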
# Load GPT-Neo model
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# Load artist and style description tables
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()
style_desc = pd.read_csv("style_desc.csv", delimiter=';')
style_desc.columns = style_desc.columns.str.lower()
# Function to enrich prompt
def enrich_prompt(artist, style):
    artist_info = dataset_desc.loc[dataset_desc['artists'].str.lower() == artist.lower(), 'description'].values
    style_info = style_desc.loc[style_desc['style'].str.lower() == style.lower(), 'description'].values
    # Fall back to keyword-based partial matching when no exact style match exists
    if len(style_info) == 0:
        style_keywords = style.lower().split()
        for keyword in style_keywords:
            safe_keyword = re.escape(keyword)
            partial_matches = style_desc[style_desc['style'].str.lower().str.contains(safe_keyword, na=False, regex=True)]
            if not partial_matches.empty:
                style_info = partial_matches['description'].values
                break
    artist_details = artist_info[0] if len(artist_info) > 0 else ""
    style_details = style_info[0] if len(style_info) > 0 else ""
    return f"{artist_details} This work exemplifies {style_details}."
# Function to generate description
def generate_description(image_path):
    image = Image.open(image_path).convert("RGB")
    image_resnet = data_transforms(image).unsqueeze(0).to(device)

    # Predict style and artist
    with torch.no_grad():
        outputs_style, outputs_artist = model(image_resnet)
        _, predicted_style_idx = torch.max(outputs_style, 1)
        _, predicted_artist_idx = torch.max(outputs_artist, 1)

    # Invert the label maps to recover the class names
    idx_to_style = {v: k for k, v in dataset.label_map_style.items()}
    idx_to_artist = {v: k for k, v in dataset.label_map_artist.items()}
    predicted_style = idx_to_style[predicted_style_idx.item()]
    predicted_artist = idx_to_artist[predicted_artist_idx.item()]

    # Enrich prompt
    enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
    full_prompt = (
        f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
        "Describe its distinctive features, considering both the artist's techniques and the artistic style."
    )

    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    output = model_gptneo.generate(
        **inputs,
        max_new_tokens=250,          # bound the generated text, not prompt + text
        num_return_sequences=1,
        do_sample=True,              # temperature/top_p only take effect when sampling
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id  # GPT-Neo has no dedicated pad token
    )
    # Decode only the newly generated tokens, not the echoed prompt
    description_text = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return predicted_style, predicted_artist, description_text
# Gradio interface
def predict(image):
    style, artist, description = generate_description(image)
    return f"**Predicted Style**: {style}\n\n**Predicted Artist**: {artist}\n\n**Description**:\n{description}"
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath"),
    outputs=gr.Markdown(),  # predict() returns Markdown, so render it as such
    title="AI-Powered Artwork Recognition and Description",
    description="Upload an image of artwork to predict its style and artist, and generate a description."
)
if __name__ == "__main__":
    iface.launch()