Spaces:
Sleeping
Sleeping
| import json | |
| import sqlite3 | |
| import os | |
| import sys | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import torch | |
| from PIL import Image | |
| import clip | |
| import faiss | |
| import numpy as np | |
| import glob | |
# Storage locations for the metadata database and the two FAISS indexes.
# These are relative paths, so they resolve against the current working directory.
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
IMAGE_FAISS_INDEX_PATH = 'app/data/image_faiss_index.index'
TEXT_FAISS_INDEX_PATH = 'app/data/text_faiss_index.index'

# Dataset locations: images live under main_data/, captions under captions/.
DATA_ROOT = '/Users/artteiv/Desktop/Graduated/chore-graduated/Data'
MAIN_DATA_PATH = os.path.join(DATA_ROOT, 'main_data')
CAPTIONS_PATH = os.path.join(DATA_ROOT, 'captions')

# sqlite3.connect() creates the database file on demand but not its parent
# directory; without this step a fresh checkout fails with
# "unable to open database file".
_db_dir = os.path.dirname(VECTOR_EMBEDDINGS_DB_PATH)
if _db_dir:
    os.makedirs(_db_dir, exist_ok=True)

# Module-wide SQLite connection and cursor shared by the insert helpers.
conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH)
cursor = conn.cursor()

# Metadata tables for image and text embeddings. e_index is the row's position
# in the corresponding FAISS index.
cursor.execute('''
    CREATE TABLE IF NOT EXISTS image_embeddings (
        e_index INTEGER PRIMARY KEY,
        image_path TEXT NOT NULL,
        caption TEXT NOT NULL,
        category TEXT NOT NULL,
        subcategory TEXT NOT NULL
    )
''')
cursor.execute('''
    CREATE TABLE IF NOT EXISTS text_embeddings (
        e_index INTEGER PRIMARY KEY,
        text TEXT NOT NULL,
        category TEXT NOT NULL,
        subcategory TEXT NOT NULL
    )
''')
def insert_image_embedding(e_index, image_path, caption, category, subcategory):
    """Store the metadata row for one image embedding and commit it.

    Args:
        e_index: Primary key of the row in ``image_embeddings``.
        image_path: Path of the image file.
        caption: Caption text associated with the image.
        category: Category label of the image.
        subcategory: Subcategory label of the image.
    """
    row = (e_index, image_path, caption, category, subcategory)
    statement = '''
        INSERT INTO image_embeddings (e_index, image_path, caption, category, subcategory)
        VALUES (?, ?, ?, ?, ?)
    '''
    # Values are bound as parameters; each row is committed immediately.
    cursor.execute(statement, row)
    conn.commit()
    print(f"Đã thêm embedding ảnh: {image_path}")
def insert_text_embedding(e_index, text, category, subcategory):
    """Store the metadata row for one text embedding and commit it.

    Args:
        e_index: Primary key of the row in ``text_embeddings``.
        text: The text that was embedded.
        category: Category label of the text.
        subcategory: Subcategory label of the text.
    """
    row = (e_index, text, category, subcategory)
    statement = '''
        INSERT INTO text_embeddings (e_index, text, category, subcategory)
        VALUES (?, ?, ?, ?)
    '''
    # Values are bound as parameters; each row is committed immediately.
    cursor.execute(statement, row)
    conn.commit()
    # Only the first 50 characters are echoed to keep the log readable.
    print(f"Đã thêm embedding văn bản: {text[:50]}...")
def save_faiss_index(index, index_file):
    """Write a FAISS index to disk and report the destination on stdout.

    Args:
        index: The FAISS index object to persist.
        index_file: Destination path handed to ``faiss.write_index``.
    """
    faiss.write_index(index, index_file)
    confirmation = f"Đã lưu FAISS index vào {index_file}"
    print(confirmation)
def load_faiss_index(index_file):
    """Read a FAISS index from disk.

    Args:
        index_file: Path of the index file to read.

    Returns:
        The index returned by ``faiss.read_index``, or ``None`` when
        ``index_file`` does not exist.
    """
    # Guard clause: a missing file is reported to the caller as None.
    if not os.path.exists(index_file):
        return None
    loaded_index = faiss.read_index(index_file)
    print(f"Đã nạp FAISS index từ {index_file}")
    return loaded_index
def _collect_captioned_images():
    """Pair every file under MAIN_DATA_PATH/<category>/<subcategory>/ with its caption.

    The caption for ``<name>.<ext>`` is read from
    ``CAPTIONS_PATH/<category>/<subcategory>/<name>.txt``. Files without a
    caption file, or whose caption cannot be read, are skipped.

    Returns:
        Four parallel lists: image paths, captions, categories, subcategories.
    """
    image_paths = []
    captions = []
    categories_list = []
    subcategories_list = []

    for category in os.listdir(MAIN_DATA_PATH):
        category_path = os.path.join(MAIN_DATA_PATH, category)
        if not os.path.isdir(category_path):
            continue
        for subcategory in os.listdir(category_path):
            subcategory_image_path = os.path.join(category_path, subcategory)
            if not os.path.isdir(subcategory_image_path):
                continue
            subcategory_caption_path = os.path.join(CAPTIONS_PATH, category, subcategory)

            for img_path in glob.glob(os.path.join(subcategory_image_path, '*.*')):
                # The caption file shares the image's base name, with a .txt extension.
                base_name = os.path.splitext(os.path.basename(img_path))[0]
                caption_file = os.path.join(subcategory_caption_path, f"{base_name}.txt")
                if not os.path.exists(caption_file):
                    continue
                try:
                    with open(caption_file, 'r', encoding='utf-8') as f:
                        caption = f.read().strip()
                except (OSError, UnicodeDecodeError) as e:
                    # Best effort: one unreadable caption must not abort the whole run.
                    print(f"Error processing {img_path}: {e}")
                    continue

                image_paths.append(img_path)
                captions.append(caption)
                categories_list.append(category)
                subcategories_list.append(subcategory)

    return image_paths, captions, categories_list, subcategories_list


def compute_embeddings():
    """Compute CLIP embeddings for the captioned dataset and persist them.

    Loads CLIP ``ViT-B/32`` (on CUDA when available, otherwise CPU), gathers
    every captioned image under ``MAIN_DATA_PATH``, encodes the captions,
    L2-normalises them and stores them in an inner-product FAISS index saved
    to ``TEXT_FAISS_INDEX_PATH``. One metadata row per caption is written to
    the ``text_embeddings`` table, keyed by the caption's position in the index.

    The image-embedding pipeline is currently disabled (kept below as
    commented-out code), so the returned image index is always ``None``.

    Returns:
        A tuple ``(image_index, text_index)``. Each element is a FAISS index,
        or ``None`` when that index was not built.
    """
    print("Loading CLIP model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    print("Model loaded")

    print("Processing data from directories...")
    image_paths, captions, categories_list, subcategories_list = _collect_captioned_images()
    texts = captions  # the captions double as the text corpus

    # Both results start as None so the return statement is always well defined,
    # even while the image pipeline below is disabled.
    image_index = None
    text_index = None

    # Image embeddings (disabled).
    # if image_paths:
    #     print("Computing image embeddings...")
    #     image_embeddings = []
    #     for idx, img_path in enumerate(image_paths):
    #         try:
    #             image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
    #             with torch.no_grad():
    #                 image_features = model.encode_image(image)
    #             image_features = image_features.cpu().numpy()
    #             faiss.normalize_L2(image_features)
    #             image_embeddings.append(image_features[0])
    #             insert_image_embedding(idx, img_path, captions[idx], categories_list[idx], subcategories_list[idx])
    #         except Exception as e:
    #             print(f"Error processing image {img_path}: {e}")
    #             continue
    #     if image_embeddings:
    #         image_embeddings = np.array(image_embeddings)
    #         d = image_embeddings.shape[1]
    #         image_index = faiss.IndexFlatIP(d)
    #         image_index.add(image_embeddings)
    #         save_faiss_index(image_index, IMAGE_FAISS_INDEX_PATH)

    # Text embeddings.
    if texts:
        print("Computing text embeddings...")
        # NOTE: all captions are tokenised and encoded in a single forward pass.
        text_tokens = clip.tokenize(texts, truncate=True).to(device)
        print("Kích thước của text_tokens:", text_tokens.shape)
        with torch.no_grad():
            text_features = model.encode_text(text_tokens)
        # CLIP can return half-precision features on CUDA, while FAISS works on
        # float32 arrays, so cast before leaving torch.
        text_features = text_features.float().cpu().numpy()
        # Unit-length vectors make inner product equal to cosine similarity.
        faiss.normalize_L2(text_features)

        d = text_features.shape[1]
        text_index = faiss.IndexFlatIP(d)
        text_index.add(text_features)

        # The index is rebuilt from scratch with positions 0..n-1, so rows left
        # over from an earlier run are stale and would collide on the primary key.
        cursor.execute('DELETE FROM text_embeddings')
        conn.commit()
        for idx, (text, category, subcategory) in enumerate(zip(texts, categories_list, subcategories_list)):
            insert_text_embedding(idx, text, category, subcategory)

        save_faiss_index(text_index, TEXT_FAISS_INDEX_PATH)

    print("Processing completed")
    return image_index, text_index
def predict_image(image_path, k=10):
    """Search the saved image index for the nearest neighbours of one image.

    The image is encoded with CLIP ``ViT-B/32`` (on CUDA when available,
    otherwise CPU), L2-normalised, and used to query the FAISS index stored at
    ``IMAGE_FAISS_INDEX_PATH``.

    Args:
        image_path: Path of the query image.
        k: Number of neighbours to retrieve. Defaults to 10.

    Returns:
        The ``(distances, indices)`` pair returned by ``index.search`` for the
        single query vector.

    Raises:
        FileNotFoundError: If no index file exists at ``IMAGE_FAISS_INDEX_PATH``.
    """
    # Check for the index first so a missing file fails fast, before the
    # comparatively expensive model load.
    index = load_faiss_index(IMAGE_FAISS_INDEX_PATH)
    if index is None:
        raise FileNotFoundError(
            f"Image FAISS index not found at {IMAGE_FAISS_INDEX_PATH}; build it before searching."
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    # The context manager closes the underlying file once preprocessing has
    # produced a tensor.
    with Image.open(image_path) as img:
        image = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image)
    # CLIP can return half-precision features on CUDA, while FAISS works on
    # float32 arrays, so cast before leaving torch.
    image_features = image_features.float().cpu().numpy()
    faiss.normalize_L2(image_features)

    distances, indices = index.search(image_features, k=k)
    return distances, indices
if __name__ == '__main__':
    # Build the indexes, report how many vectors each one holds, and always
    # release the shared SQLite connection, even if the build fails.
    try:
        image_index, text_index = compute_embeddings()
        for label, built_index in (("Image", image_index), ("Text", text_index)):
            if not built_index:
                continue
            print(f"{label} index ready with {built_index.ntotal} embeddings")
    finally:
        conn.close()
        print("SQLite connection closed")