import os
import re
import torch
import gradio as gr
from tqdm import tqdm
from datasets import load_dataset, DatasetDict
from transformers import AutoModelForCausalLM, AutoTokenizer

# Automatically detect GPU or use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Default model path
model_tokenizer_path = "zehui127/Omni-DNA-Multitask"

# Load tokenizer and model with trusted remote code
tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path, trust_remote_code=True).to(device)

# List of available tasks
tasks = ['H3', 'H4', 'H3K9ac', 'H3K14ac', 'H4ac', 'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K36me3', 'H3K79me3']

# Map the predicted token to a human-readable statement
mapping = {
    '1': 'It is a',
    '0': 'It is not a',
    'No valid prediction': 'Cannot be determined whether or not it is a',
}

def preprocess_response(response, mask_token="[MASK]"):
    """Extracts the portion of the response after the [MASK] token."""
    if mask_token in response:
        response = response.split(mask_token, 1)[1]
    # Drop leading whitespace and any echoed nucleotide characters (A/T/G/C)
    response = re.sub(r'^[\sATGC]+', '', response)
    return response

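# Illustrative sketch (assumption, not a captured model output): on a hypothetical
# decoded reply such as "ACGTACGTH3[MASK] 1", preprocess_response keeps only what
# follows "[MASK]" and strips leading whitespace/nucleotides, returning "1".
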
def generate(dna_sequence, task_type, sample_num=1):
    """
    Generates a prediction for the input DNA sequence and selected task.

    Args:
        dna_sequence (str): The input DNA sequence.
        task_type (str): The selected task type.
        sample_num (int): Maximum number of new tokens to generate.

    Returns:
        str: Predicted function label.
    """
    if task_type is None:
        task_type = 'H3'
    # Build the prompt: sequence, task name, then a [MASK] token for the model to fill in
    dna_sequence = dna_sequence + task_type + "[MASK]"
    tokenized_message = tokenizer(
        [dna_sequence], return_tensors='pt', return_token_type_ids=False, add_special_tokens=True
    ).to(device)
    # Greedy decoding (do_sample=False) of the answer token(s)
    response = model.generate(**tokenized_message, max_new_tokens=sample_num, do_sample=False)
    reply = tokenizer.batch_decode(response, skip_special_tokens=False)[0].replace(" ", "")
    pred = extract_label(reply, task_type)
    return f"{mapping[pred]} {task_type}"

def extract_label(message, task_type):
    """Extracts the prediction label from the model's response."""
    # The predicted label follows the [MASK] token in the decoded output
    parts = message.split('[MASK]', 1)
    if len(parts) < 2:
        return "No valid prediction"
    match = re.search(r'\d+', parts[1])
    return match.group() if match else "No valid prediction"

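# Illustrative sketch (assumed output format, not a captured response):
#   extract_label("ACGTACGTH3K4me3[MASK]1", "H3K4me3")  ->  "1"
#   extract_label("ACGTACGT", "H3")                     ->  "No valid prediction"
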
# Gradio interface
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input DNA Sequence", placeholder="Enter a DNA sequence"),
        gr.Dropdown(choices=tasks, label="Select Task Type"),
    ],
    outputs=gr.Textbox(label="Predicted Type"),
    title="Omni-DNA Multitask Prediction",
    description="Select a DNA-related task and input a sequence to generate function predictions.",
)

if __name__ == "__main__":
    interface.launch()
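
# Illustrative usage (assumption, not part of the original app): generate() can also be
# called directly, bypassing the Gradio UI. The sequence below is a placeholder, not a
# biologically validated input.
#   print(generate("ACGTTGCA" * 10, "H3K4me3"))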