Update app.py
app.py
CHANGED

@@ -7,182 +7,197 @@ import onnxruntime as ort
 from huggingface_hub import hf_hub_download
 import gradio as gr

 class ONNXInferencePipeline:
-    def __init__(self, repo_id):
-        # …
         hf_token = os.getenv("HF_TOKEN")
-        if hf_token is None:
-            raise ValueError("HF_TOKEN environment variable is not set.")

         # Load banned keywords list
         self.banned_keywords = self.load_banned_keywords()
         print(f"Loaded {len(self.banned_keywords)} banned keywords")
-
-        # Download …
-        self.onnx_path = hf_hub_download(
-            …

         # Load configuration
-        with open(self.config_path) as f:
             self.config = json.load(f)

         # Initialize tokenizer
         self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
-        self.max_len = 256

         # Initialize ONNX runtime session
-        …
-        if …
-        …

     def load_banned_keywords(self):
-        …
         code_str = os.getenv("banned")
         if not code_str:
-            …

-        # …
         local_vars = {}
-
-        # Wrap the code in a function to allow return statements
         wrapped_code = f"""
 def get_banned_keywords():
 {textwrap.indent(code_str, '    ')}
 """
-
         try:
-            …
         except Exception as e:
-            print(f"Error loading banned keywords: {e}")
-            # Return a default empty list if there's an error
             return []

     def contains_banned_keyword(self, text):
-        """Check if the input text contains any banned keywords."""
         text_lower = text.lower()
-        …
         for keyword in self.banned_keywords:
-            …
-            if …:
                 print(f"Keyword detected: '{keyword}'")
                 return True
-
-        print("Keywords Passed - No inappropriate keywords found")
         return False

     def preprocess(self, text):
         encoding = self.tokenizer.encode(text)
-        ids = encoding.ids[:self.max_len]
         padding = [0] * (self.max_len - len(ids))
         return np.array(ids + padding, dtype=np.int64).reshape(1, -1)

     def predict(self, text):
-        …
         if self.contains_banned_keyword(text):
             print("Input rejected by keyword filter")
             return {
-                …
             }
-
-        # If no banned keywords found, proceed with model prediction
-        print("Running ML model for classification...")
-
         # Preprocess
         input_array = self.preprocess(text)

-        # Run inference
-        …
-            None,
-            {'input': input_array}
-        )

-        # Post…
-        logits = …
-        …
-        # Log model result
-        print(f"Model Passed - Result: {class_labels[predicted_class]} (Confidence: {probabilities[0][predicted_class]:.2%})")
-
-        return {
-            'label': class_labels[predicted_class],
-            'probabilities': probabilities[0].tolist()
-        }
-
-# Example usage
-if __name__ == "__main__":
-    # Initialize the pipeline with the Hugging Face repository ID
-    print("Initializing content filter pipeline...")
-    pipeline = ONNXInferencePipeline(repo_id="iimran/abuse-detector")
-    print("Pipeline initialized successfully")
-
-    # Example texts for testing
-    example_texts = [
-        "You're a worthless piece of garbage who should die",
-        "Hello HR, I hope this message finds you well. I'm writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team."
-    ]
-
-    for text in example_texts:
-        result = pipeline.predict(text)
-        print(f"Input: {text[:50]}...")
-        print(f"Prediction: {result['label']} ")
-        print("-" * 80)
-
-    # Define a function for Gradio to use
-    def gradio_predict(text):
-        result = pipeline.predict(text)
-        return (
-            f"Prediction: {result['label']} \n"
-        )

-
     iface = gr.Interface(
         fn=gradio_predict,
         inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
         outputs="text",
         title="Abuse Detector - Offensive Language Detector",
         description=(
-            …
         examples=[
-            …
-            "I appreciate your help, but honestly, you're such a clueless idiot sometimes. Still, thanks for trying."
-        ]
     )
-
-    # Launch the Gradio app
     print("Launching Gradio interface...")
-    # at the very bottom where you launch Gradio
     iface.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))

 from huggingface_hub import hf_hub_download
 import gradio as gr

+
 class ONNXInferencePipeline:
+    def __init__(self, repo_id, repo_type="model"):
+        # Read token from env. In a Space, HF_TOKEN can be set in the Secrets panel.
         hf_token = os.getenv("HF_TOKEN")

         # Load banned keywords list
         self.banned_keywords = self.load_banned_keywords()
         print(f"Loaded {len(self.banned_keywords)} banned keywords")
+
+        # Download artifacts. Newer huggingface_hub uses token=, not use_auth_token=
+        self.onnx_path = hf_hub_download(
+            repo_id=repo_id,
+            filename="model.onnx",
+            token=hf_token,
+            repo_type=repo_type
+        )
+        self.tokenizer_path = hf_hub_download(
+            repo_id=repo_id,
+            filename="train_bpe_tokenizer.json",
+            token=hf_token,
+            repo_type=repo_type
+        )
+        self.config_path = hf_hub_download(
+            repo_id=repo_id,
+            filename="hyperparameters.json",
+            token=hf_token,
+            repo_type=repo_type
+        )

         # Load configuration
+        with open(self.config_path, "r") as f:
             self.config = json.load(f)

         # Initialize tokenizer
         self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
+        self.max_len = int(self.config.get("max_len", 256))

         # Initialize ONNX runtime session
+        # Spaces CPU runtime typically uses CPUExecutionProvider
+        providers = ort.get_available_providers()
+        if "CUDAExecutionProvider" in providers:
+            use_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        else:
+            use_providers = ["CPUExecutionProvider"]
+
+        sess_options = ort.SessionOptions()
+        # Reduce memory use and improve cold start a bit
+        sess_options.enable_mem_pattern = False
+        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        self.session = ort.InferenceSession(self.onnx_path, sess_options=sess_options, providers=use_providers)
+
+        # Cache the model input name to avoid mismatches like input vs input_ids
+        self.input_name = self.session.get_inputs()[0].name
+        print(f"ONNX model input name detected: {self.input_name}")
+
+        # Label order can also be read from the config
+        self.class_labels = self.config.get("class_labels", ["Inappropriate Content", "Appropriate"])

     def load_banned_keywords(self):
+        """
+        Load banned keywords from the env var named 'banned'.
+        Supports two formats:
+        1) a Python code snippet that returns a list (the current method)
+        2) a JSON array of strings
+        """
         code_str = os.getenv("banned")
         if not code_str:
+            print("Environment variable 'banned' is not set. Using empty list.")
+            return []

+        # Try JSON first
+        try:
+            parsed = json.loads(code_str)
+            if isinstance(parsed, list) and all(isinstance(x, str) for x in parsed):
+                return parsed
+        except Exception:
+            pass
+
+        # Fall back to executable code that returns a list
         local_vars = {}
         wrapped_code = f"""
 def get_banned_keywords():
 {textwrap.indent(code_str, '    ')}
 """
         try:
+            exec(wrapped_code, {}, local_vars)
+            result = local_vars["get_banned_keywords"]()
+            if isinstance(result, list):
+                return [str(x) for x in result]
+            print("Banned keywords code did not return a list. Using empty list.")
+            return []
         except Exception as e:
+            print(f"Error loading banned keywords from code: {e}")
             return []

     def contains_banned_keyword(self, text):
+        """Check if the input text contains any banned keywords as whole words."""
         text_lower = text.lower()
+        words = "".join(c if c.isalnum() else " " for c in text_lower).split()
+        word_set = set(words)
+
         for keyword in self.banned_keywords:
+            kw = str(keyword).lower().strip()
+            if not kw:
+                continue
+            if kw in word_set:
                 print(f"Keyword detected: '{keyword}'")
                 return True
+        print("Keywords Passed. No inappropriate keywords found")
         return False

     def preprocess(self, text):
         encoding = self.tokenizer.encode(text)
+        ids = encoding.ids[: self.max_len]
         padding = [0] * (self.max_len - len(ids))
         return np.array(ids + padding, dtype=np.int64).reshape(1, -1)

+    @staticmethod
+    def softmax(logits):
+        # Numerically stable softmax
+        x = logits - np.max(logits, axis=1, keepdims=True)
+        e = np.exp(x)
+        return e / np.sum(e, axis=1, keepdims=True)
+
     def predict(self, text):
+        snippet = text[:50].replace("\n", " ")
+        print(f"\nProcessing input: '{snippet}...' ({len(text)} characters)")
+
+        # First, the rule-based filter
         if self.contains_banned_keyword(text):
             print("Input rejected by keyword filter")
             return {
+                "label": self.class_labels[0],
+                "probabilities": [1.0, 0.0] if len(self.class_labels) == 2 else [1.0] * len(self.class_labels),
             }
+
         # Preprocess
         input_array = self.preprocess(text)

+        # Run inference, using the detected input name
+        outputs = self.session.run(None, {self.input_name: input_array})

+        # Post-process
+        logits = outputs[0]
+        probs = self.softmax(logits)
+        pred_idx = int(np.argmax(probs))
+        label = self.class_labels[pred_idx] if pred_idx < len(self.class_labels) else str(pred_idx)
+
+        print(f"Model Passed. Result: {label} (Confidence: {probs[0][pred_idx]:.2%})")
+        return {"label": label, "probabilities": probs[0].tolist()}

+
+# Gradio glue
+def gradio_predict(text):
+    result = PIPELINE.predict(text)
+    return f"Prediction: {result['label']}\n"
+
+
+# Create the pipeline at import time so the Space is ready
+print("Initializing content filter pipeline...")
+PIPELINE = ONNXInferencePipeline(repo_id="iimran/abuse-detector", repo_type="model")
+print("Pipeline initialized successfully")
+
+
+if __name__ == "__main__":
+    # Required in Spaces: PORT is injected; bind to 0.0.0.0
     iface = gr.Interface(
         fn=gradio_predict,
         inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
         outputs="text",
         title="Abuse Detector - Offensive Language Detector",
         description=(
+            "Abuse detector identifies inappropriate content in text. "
+            "It analyzes input for Australian slang and abusive language. "
+            "It is trained on a compact dataset, so it may not catch highly nuanced language, "
+            "but it detects common day-to-day offensive language."
+        ),
         examples=[
+            # Explicitly offensive examples
+            "Congrats, you fuckbrain arsehole, you have outdone yourself in stupidity. A real cock up of a human. Should we clap for your bollocks faced greatness or just pity you?",
+            "You are a mad bastard, but I would still grab a beer with you. Mess around all you like, you cockheaded legend. Your arsehole antics are bloody brilliant.",
+            "Your mother should have done better raising such a useless idiot.",
+            # Neutral or appropriate examples
+            "Hello HR, I hope this message finds you well. I am writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
+            "Thank you for your time and consideration. Please reach out if you need anything. I would be happy to discuss further.",
+            "The weather today is lovely, and I am looking forward to a productive day at work.",
+            # Mixed
+            "I appreciate your help, but honestly, you are such a clueless idiot sometimes. Still, thanks for trying."
+        ],
     )
     print("Launching Gradio interface...")
     iface.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
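
A note on the 'banned' secret, since the new load_banned_keywords supports two formats. A minimal sketch of both, assuming HF_TOKEN is available so the constructor can download the model files; the keyword values are placeholders, not the Space's real list:

    import os, json

    # Format 1: JSON array of strings (taken by the json.loads branch)
    os.environ["banned"] = json.dumps(["keyword1", "keyword2"])

    # Format 2: Python snippet; it is wrapped in get_banned_keywords() and exec'd,
    # so a bare return statement is valid here
    # os.environ["banned"] = 'return ["keyword1", "keyword2"]'

    pipe = ONNXInferencePipeline(repo_id="iimran/abuse-detector")
    result = pipe.predict("some input text")
    print(result["label"], result["probabilities"])

And a quick offline way to see the input name that __init__ detects, assuming model.onnx has been downloaded locally; the printed name and shape are illustrative, not confirmed for this model:

    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    print([(i.name, i.shape, i.type) for i in sess.get_inputs()])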