|
|
| import random |
| import numpy as np |
| import torch |
| from PIL import Image |
| Image.MAX_IMAGE_PIXELS = None |
| from transformers import AutoModel, AutoTokenizer |
| import opencc |
| from ultralytics import YOLO |
| from config.configu import * |
| from utils.utils import * |
| import logging |
| import argparse |
|
|
def setup_logger(log_file):
    """Configure root logging to append timestamped INFO records to *log_file*.

    Returns the root logger so callers can emit records directly.
    """
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
    )
    return logging.getLogger()
|
|
def set_seed(seed):
    """Seed every RNG the pipeline touches (python, numpy, torch) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Force deterministic cuDNN kernels and disable autotuning so repeated
    # runs with the same seed produce identical results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
# Traditional -> Simplified Chinese converter (explicit config-file form).
cc = opencc.OpenCC('t2s.json')
# SEED is expected to come from config.configu via the star import above.
set_seed(SEED)
# NOTE(review): 't2s' and 't2s.json' name the same OpenCC conversion, so this
# looks like a duplicate of `cc` -- confirm whether both instances are needed.
converter_t2s = opencc.OpenCC('t2s')
|
|
|
|
def single_rec(model, tokenizer, detect_model, generation_config, image_path, prompt,
               use_p, hard_vq, drop_zero, repetition_penalty, verbose):
    """Run OCR-style chat on a single image and print the exchange.

    Args:
        model: chat model exposing `chat_ocr`.
        tokenizer: tokenizer passed through to `chat_ocr`.
        detect_model: detection model (YOLO) passed through to `chat_ocr`.
        generation_config: dict of generation options.
        image_path: path of the image to recognize.
        prompt: user prompt text.
        use_p: whether to use the perceiver resampler.
        hard_vq: whether to use closest-similarity (hard VQ) matching.
        drop_zero: whether to drop zero padding in pseudo tokens.
        repetition_penalty: generation repetition penalty.
        verbose: whether `chat_ocr` prints extra information.

    Returns:
        The model's text response. (Previously the response was only printed
        and discarded; returning it lets callers reuse the result.)
    """
    response, _history = model.chat_ocr(
        tokenizer, detect_model, image_path, prompt, generation_config,
        use_p=use_p,
        hard_vq=hard_vq,
        drop_zero=drop_zero,
        repetition_penalty=repetition_penalty,
        return_history=True,
        verbose=verbose,
    )
    print(f'User: {prompt}\nAssistant: {response}')
    return response
|
|
def folder_rec(model, tokenizer, detect_model, generation_config, folder_path, prompt,
               save_name, use_p, hard_vq, drop_zero, repetition_penalty, verbose):
    """Recognize every image under *folder_path* and dump the results to JSON.

    Each result entry records the image path, the prompt, and the model
    response ("ERROR!" when inference raised). Results are written to
    *save_name* (a `_result.json` suffix is appended when the name lacks a
    `.json` extension) and also returned for programmatic use.
    """
    results = []

    all_images = get_image_paths(folder_path)
    for pic in tqdm(all_images):
        pic_path = os.path.join(folder_path, pic)
        try:
            response, _history = model.chat_ocr(
                tokenizer, detect_model, pic_path, prompt, generation_config,
                use_p=use_p,
                hard_vq=hard_vq,
                drop_zero=drop_zero,
                repetition_penalty=repetition_penalty,
                return_history=True,
                verbose=verbose,
            )
        except Exception as e:
            # Best-effort batch mode: record the failure and keep going.
            print(f"An error has occurred:\n{e}")
            response = "ERROR!"
        print(f'User: {prompt}\nAssistant: {response}')
        results.append({"imagePath": pic_path, 'prompt': prompt, 'response': response})
    # Bug fix: check for the '.json' extension, not any name merely ending
    # in the letters 'json' (e.g. "foojson").
    if not save_name.endswith('.json'):
        save_name += '_result.json'
    save_json(save_name, results)
    return results
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="args for inference task") |
|
|
| parser.add_argument('--tgt', type=str,help='Recognition target') |
| parser.add_argument('--prompt', type=str,default='这幅书法作品内容是什么?',help='Prompt for recognition') |
| parser.add_argument('--save_name',type=str,default="recognition.json",help="Storage of results if multiple images recognition mode") |
|
|
| parser.add_argument('--use_p', type=bool, default=True,help='Decide the usage of perceiver resampler') |
| parser.add_argument('--hard_vq', type=bool, default=False,help='Decide the usage of closest similarity match') |
| parser.add_argument('--drop_zero', type=bool, default=False,help='Decide the deletion of zero padding in pseudo tokens') |
| parser.add_argument('--verbose', type=bool, default=False,help='Decide the output of extra information') |
| parser.add_argument('--repetition_penalty', type=float, default=1.0,help='Repetition penalty for generation') |
| |
|
|
| args = parser.parse_args() |
|
|
| if not isinstance(args.tgt,str): |
| raise ValueError(f"The target should a string, not a instance of {type(args.tgt)}!") |
|
|
| |
| model = AutoModel.from_pretrained( |
| INTERNVL_PATH, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| trust_remote_code=True).eval().cuda() |
| tokenizer = AutoTokenizer.from_pretrained(INTERNVL_PATH, trust_remote_code=True) |
|
|
| generation_config = dict( |
| num_beams=1, |
| max_new_tokens=1024, |
| do_sample=False, |
| ) |
|
|
| detect_model=YOLO(YOLO_CHECKPOINT) |
| if is_image(args.tgt): |
| print("Single image recognition mode.") |
| single_rec( |
| model, |
| tokenizer, |
| detect_model, |
| generation_config, |
| args.tgt, |
| args.prompt, |
| args.use_p, |
| args.hard_vq, |
| args.drop_zero, |
| args.repetition_penalty, |
| args.verbose) |
| elif os.path.isdir(args.tgt): |
| print("Multiple images recognition mode") |
| os.makedirs('results',exist_ok=True) |
| folder_rec( |
| model, |
| tokenizer, |
| detect_model, |
| generation_config, |
| args.tgt, |
| args.prompt, |
| os.path.join('results',args.save_name), |
| args.use_p, |
| args.hard_vq, |
| args.drop_zero, |
| args.repetition_penalty, |
| args.verbose) |
| else: |
| raise ValueError(f"The target should be either a image path or a folder that contain images!") |
|
|
def single_image_wrapped(image, prompts):
    """Recognize a single in-memory PIL image (e.g. a UI upload).

    Loads the chat model, tokenizer and YOLO detector, persists *image* to a
    fixed temp path (``single_rec`` reads from a file path), and runs
    recognition with the demo defaults: use_p=True, hard_vq=False,
    drop_zero=True, repetition_penalty=1.2, verbose=False.

    Returns:
        The recognition result produced by `single_rec` (previously discarded,
        so UI callers got nothing back).

    NOTE(review): the model is reloaded from disk on every call -- expensive;
    consider caching the loaded models if this is invoked repeatedly.
    """
    model = AutoModel.from_pretrained(
        INTERNVL_PATH,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(INTERNVL_PATH, trust_remote_code=True)

    # Greedy decoding: deterministic output for recognition tasks.
    generation_config = dict(
        num_beams=1,
        max_new_tokens=1024,
        do_sample=False,
    )
    detect_model = YOLO(YOLO_CHECKPOINT)

    temp_dir = "temp_images"
    os.makedirs(temp_dir, exist_ok=True)

    temp_image_path = os.path.join(temp_dir, "uploaded_image.png")
    image.save(temp_image_path)
    return single_rec(
        model,
        tokenizer,
        detect_model,
        generation_config,
        temp_image_path,
        prompts,
        True,    # use_p
        False,   # hard_vq
        True,    # drop_zero
        1.2,     # repetition_penalty
        False)   # verbose
# Script entry point.
if __name__ == '__main__':
    main()
|
|
|
|
|
|