# Source: page_position_binary_classif/generate_embeddings.py (last commit by toni99c: "Update generate_embeddings.py", 4a56eda)
# Given a DataFrame `tasks_df` with an 'image_path' column containing the paths of all images,
# this script produces an 'image_embeddings.pickle' file containing one embedding per image.
# You can stop and resume at any time: the script restarts after the last image saved to that file.
#
import os
import sys

import pandas as pd
import torch
from transformers import AutoModel, AutoProcessor
from transformers import pipeline
from transformers.image_utils import load_image
# Hugging Face checkpoint used for both the image processor and the model.
ckpt = "google/siglip2-so400m-patch16-512"
processor = AutoProcessor.from_pretrained(ckpt)
# device_map="auto" lets transformers decide where the weights are placed.
model = AutoModel.from_pretrained(ckpt, device_map="auto")
model.eval()  # inference only: put the modules in eval mode
# tasks_df must be a DataFrame with an 'image_path' column listing every image
# to embed. Keep its row order fixed across runs: resuming works by counting the
# rows already saved, so reordering tasks_df between runs misaligns embeddings.
tasks_df = None  # TODO: load your DataFrame here, e.g. pd.read_csv(...) or pd.read_pickle(...)
if tasks_df is None:
    raise SystemExit("tasks_df is not set: load a DataFrame with an 'image_path' column before running.")
if 'image_path' not in tasks_df.columns:
    raise SystemExit("tasks_df has no 'image_path' column.")
save_interval = 100  # save embeddings file every save_interval images
EMBEDDINGS_PATH = 'image_embeddings.pickle'  # checkpoint file, which is also the final output


def _write_checkpoint(df):
    """Write df to EMBEDDINGS_PATH atomically.

    The pickle is written to a temporary file and then renamed over the
    checkpoint. An interruption mid-write therefore never leaves a truncated
    file that the next run would fail to read.
    """
    tmp_path = EMBEDDINGS_PATH + '.tmp'
    df.to_pickle(tmp_path)
    os.replace(tmp_path, EMBEDDINGS_PATH)


def _merge_pending(df, rows):
    """Return df with rows (a list of {'image_embedding': ...} dicts) appended."""
    if not rows:
        return df
    new_df = pd.DataFrame(rows)
    if df.empty:
        # Nothing to append to. This also sidesteps pandas' deprecation
        # warning for concatenating empty frames.
        return new_df
    return pd.concat([df, new_df], ignore_index=True)


# Resume from an earlier run: each row already in the checkpoint corresponds,
# by position, to one tasks_df row that has been embedded. Only a missing file
# means "start fresh". Any other read error (e.g. a corrupted pickle) is
# raised, because restarting from 0 would overwrite the saved progress on the
# first checkpoint.
try:
    embeddings_df = pd.read_pickle(EMBEDDINGS_PATH)
except FileNotFoundError:
    embeddings_df = pd.DataFrame(columns=['image_embedding'])
start = embeddings_df.shape[0]

# Rows computed since the last checkpoint. Buffering them and concatenating
# once per save avoids the quadratic cost of one pd.concat per image.
pending_rows = []
image_paths = tasks_df['image_path']
try:
    for index in range(start, tasks_df.shape[0]):
        # .iloc selects by position, which matches how `start` counts saved
        # rows, even when tasks_df does not have a default 0..n-1 index.
        image = load_image(image_paths.iloc[index])
        inputs = processor(images=[image], return_tensors="pt").to(model.device)
        with torch.no_grad():
            image_embeddings = model.get_image_features(**inputs)
        # Move to CPU so the embeddings do not pile up in accelerator memory
        # and the pickle can be loaded on a machine without that device.
        pending_rows.append({'image_embedding': image_embeddings.cpu()})
        if index % save_interval == 0:
            embeddings_df = _merge_pending(embeddings_df, pending_rows)
            pending_rows = []
            _write_checkpoint(embeddings_df)
finally:
    # Persist everything finished since the last checkpoint. This covers the
    # tail after the final multiple of save_interval. On Ctrl-C or an error,
    # it also saves the images completed before the interruption.
    if pending_rows:
        embeddings_df = _merge_pending(embeddings_df, pending_rows)
        pending_rows = []
        _write_checkpoint(embeddings_df)