import os
import re

import torch
import gradio as gr
from tqdm import tqdm
from datasets import load_dataset, DatasetDict
from transformers import AutoModelForCausalLM, AutoTokenizer

# Automatically detect GPU or use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Default model path
model_tokenizer_path = "zehui127/Omni-DNA-Multitask"

# Load tokenizer and model with trusted remote code.
# NOTE: trust_remote_code=True runs Python code shipped with the model
# repository; only point model_tokenizer_path at a repository you trust.
tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path, trust_remote_code=True).to(device)

# Separator appended after the task name in the prompt; the model's answer
# is read from the text that follows it.
MASK_TOKEN = "[MASK]"

# List of available tasks
tasks = ['H3', 'H4', 'H3K9ac', 'H3K14ac', 'H4ac', 'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K36me3', 'H3K79me3']

# Maps the label returned by extract_label() to the sentence prefix shown to
# the user; the task name is appended after the prefix.
mapping = {
    '1': 'It is a',
    '0': 'It is not a',
    'No valid prediction': 'Cannot be determined whether or not it is a',
}


def preprocess_response(response, mask_token="[MASK]"):
    """Return the text after the first ``mask_token``, minus a leading nucleotide run.

    If ``mask_token`` occurs in ``response``, everything up to and including
    its first occurrence is dropped. Any leading run of whitespace and the
    characters A/T/G/C is then removed.

    Args:
        response (str): Decoded model output.
        mask_token (str): Separator to split on.

    Returns:
        str: The cleaned remainder of ``response``.
    """
    if mask_token in response:
        response = response.split(mask_token, 1)[1]
    response = re.sub(r'^[\sATGC]+', '', response)
    return response


def generate(dna_sequence, task_type, sample_num=1):
    """Predict whether a DNA sequence has the selected property.

    Builds the prompt ``<sequence><task_type>[MASK]``, greedily decodes up to
    ``sample_num`` new tokens, and turns the digit the model emits after
    ``[MASK]`` into a human-readable sentence.

    Args:
        dna_sequence (str | None): The input DNA sequence. All whitespace
            (e.g. newlines from a pasted multi-line sequence) is removed
            before prompting.
        task_type (str | None): One of ``tasks``; ``None`` (nothing selected
            in the dropdown) falls back to ``'H3'``.
        sample_num (int): Maximum number of new tokens to generate; passed to
            ``model.generate`` as ``max_new_tokens``. Only one greedy
            sequence is produced regardless of this value.

    Returns:
        str: ``"<prefix> <task_type>"`` where the prefix comes from
        ``mapping``. Any label other than ``'0'``/``'1'`` maps to the
        "Cannot be determined" prefix. An empty sequence returns a request
        to enter one instead of querying the model.
    """
    if task_type is None:
        task_type = 'H3'
    # Drop all whitespace so a pasted multi-line sequence is one contiguous prompt.
    dna_sequence = "".join((dna_sequence or "").split())
    if not dna_sequence:
        # Nothing to classify; querying the model would only yield noise.
        return "Please enter a DNA sequence."
    prompt = dna_sequence + task_type + MASK_TOKEN
    tokenized_message = tokenizer(
        [prompt],
        return_tensors='pt',
        return_token_type_ids=False,
        add_special_tokens=True,
    ).to(device)
    response = model.generate(**tokenized_message, max_new_tokens=sample_num, do_sample=False)
    # skip_special_tokens=False keeps [MASK] in the decoded text in case the
    # tokenizer treats it as a special token; spaces in the decoded text are
    # stripped so the separator can be found verbatim.
    reply = tokenizer.batch_decode(response, skip_special_tokens=False)[0].replace(" ", "")
    pred = extract_label(reply, task_type)
    # extract_label may return any digit run (e.g. '2'); only '0'/'1' have a
    # dedicated sentence, everything else is reported as undetermined.
    prefix = mapping.get(pred, mapping['No valid prediction'])
    return f"{prefix} {task_type}"


def extract_label(message, task_type, mask_token=MASK_TOKEN):
    """Extract the numeric prediction label that follows the mask token.

    Args:
        message (str): Decoded model output (prompt plus generated tokens).
        task_type (str): Unused; kept so existing callers keep working.
        mask_token (str): Separator after which the model's answer appears.

    Returns:
        str: The first run of digits between the first and second occurrence
        of ``mask_token`` (typically ``'1'`` or ``'0'``). Returns
        ``"No valid prediction"`` when ``mask_token`` is absent or no digit
        follows it.
    """
    parts = message.split(mask_token)
    if len(parts) < 2:
        # Separator not present in the decoded text (previously an IndexError).
        return "No valid prediction"
    match = re.search(r'\d+', parts[1])
    return match.group() if match else "No valid prediction"


# Gradio interface
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input DNA Sequence", placeholder="Enter a DNA sequence"),
        # Preselect the same task generate() falls back to when none is chosen.
        gr.Dropdown(choices=tasks, value=tasks[0], label="Select Task Type"),
    ],
    outputs=gr.Textbox(label="Predicted Type"),
    title="Omni-DNA Multitask Prediction",
    description="Select a DNA-related task and input a sequence to generate function predictions.",
)

if __name__ == "__main__":
    interface.launch()