# Hugging Face Hub file-view header (captured together with the file):
# author jayebaku — "Update app.py" (commit 8fca602, verified) — 3.37 kB
import os
import time
import gradio as gr
import pandas as pd
from classifier import classify
from statistics import mean
# Hugging Face access token, forwarded to classify() for every post.
# Indexing os.environ (rather than using .get) means a missing HF_TOKEN
# raises KeyError as soon as this module is imported.
HFTOKEN = os.environ["HF_TOKEN"]
def _read_posts_table(path, text_field):
    """Read an uploaded table, trying tab- then comma-separated parsing.

    Tab-separated parsing (``pd.read_table``) is tried first because that is
    what this loader originally used, so files that worked before still work.
    If the requested column is not found, the file is re-read as
    comma-separated (``pd.read_csv``). If that parse fails, the tab-separated
    result is returned and the caller's column check reports the problem.
    """
    df = pd.read_table(path)
    if text_field in df.columns:
        return df
    try:
        return pd.read_csv(path)
    except pd.errors.ParserError:
        return df


def load_and_analyze_csv(file, text_field, event_model):
    """Classify every post in an uploaded table and group the posts by event.

    Args:
        file: Value of the ``gr.File`` input. Its ``.name`` attribute is used
            as the path when present; otherwise the value itself is treated
            as the path.
        text_field: Name of the column holding the post text.
        event_model: Model identifier forwarded to ``classify``.

    Returns:
        A 4-tuple ``(flood, fire, none, confidence)``. The first three items
        are ``gr.CheckboxGroup`` updates whose choices are the posts
        classified as "flood", as "fire", and as anything else. The last item
        is the mean of the per-post ``score`` values, rounded to 5 decimals.

    Raises:
        gr.Error: if no file was uploaded, the column is missing, or the
            column has no non-missing values.
    """
    if file is None:
        raise gr.Error("Error: Please upload a CSV file first.")
    # Accept both a file-like object exposing .name and a plain path string.
    path = getattr(file, "name", file)

    df = _read_posts_table(path, text_field)
    if text_field not in df.columns:
        raise gr.Error(f"Error: Entered text column '{text_field}' not in CSV file.")

    # Empty cells are read as NaN. They are neither classifiable text nor
    # meaningful checkbox labels, so they are skipped.
    posts = df[text_field].dropna().astype(str).to_list()
    if not posts:
        # mean() below would raise StatisticsError on an empty score list.
        raise gr.Error(f"Error: Text column '{text_field}' contains no text to classify.")

    floods, fires, nones, scores = [], [], [], []
    for post in posts:
        res = classify(post, event_model, HFTOKEN)
        if res["event"] == "flood":
            floods.append(post)
        elif res["event"] == "fire":
            fires.append(post)
        else:
            nones.append(post)
        scores.append(res["score"])

    model_confidence = round(mean(scores), 5)
    return (
        gr.CheckboxGroup(choices=floods),
        gr.CheckboxGroup(choices=fires),
        gr.CheckboxGroup(choices=nones),
        model_confidence,
    )
def analyze_selected_texts(selections):
    """Build a per-text word-count table for the checked texts.

    Args:
        selections: Iterable of texts checked in a CheckboxGroup. ``None``
            (nothing selected yet) is treated as an empty selection instead
            of raising ``TypeError``.

    Returns:
        A DataFrame with columns "Selected Text" and "Analysis". "Analysis"
        holds "Word Count: <n>", where n is the number of whitespace-separated
        tokens in the text.
    """
    selected_texts = list(selections or [])
    analysis_results = [f"Word Count: {len(text.split())}" for text in selected_texts]
    return pd.DataFrame({"Selected Text": selected_texts, "Analysis": analysis_results})
with gr.Blocks() as demo:
    event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"]

    with gr.Tab("Event Type Classification"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                file_input = gr.File(label="Upload CSV File")
            with gr.Column(scale=6):
                text_field = gr.Textbox(label="Text field name", value="tweet_text")
                # Preselect the first model so clicking "Start Prediction" without
                # touching the dropdown does not send event_model=None to classify().
                event_model = gr.Dropdown(
                    event_models,
                    value=event_models[0],
                    label="Select classification model",
                )
                predict_button = gr.Button("Start Prediction")

        with gr.Row():  # XXX confirm this is not a problem later --equal_height=True
            with gr.Column():
                gr.Markdown("""### Flood-related""")
                flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
            with gr.Column():
                gr.Markdown("""### Fire-related""")
                fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
            with gr.Column():
                gr.Markdown("""### None""")
                none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

        model_confidence = gr.Number(label="Model Confidence")

        # Output order must match the tuple returned by load_and_analyze_csv.
        predict_button.click(
            load_and_analyze_csv,
            inputs=[file_input, text_field, event_model],
            outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence],
        )

    with gr.Tab("Question Answering"):
        # XXX Add some button disabling here, if the classification process is not completed first XXX
        analysis_button = gr.Button("Analyze Selected Texts")
        analysis_output = gr.DataFrame(headers=["Selected Text", "Analysis"])
        # NOTE(review): only the flood-related selections are analysed; the fire
        # and "None" checkbox groups are not wired in. Confirm this is intended.
        analysis_button.click(analyze_selected_texts, inputs=flood_checkbox_output, outputs=analysis_output)
# Launch only when run as a script (e.g. `python app.py`). Importing the module,
# as Gradio's reload mode does, then builds `demo` without starting a second server.
if __name__ == "__main__":
    demo.launch()