iimran committed on
Commit
0c698cb
·
verified ·
1 Parent(s): 63c9bcc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -119
app.py CHANGED
@@ -7,182 +7,197 @@ import onnxruntime as ort
7
  from huggingface_hub import hf_hub_download
8
  import gradio as gr
9
 
 
10
  class ONNXInferencePipeline:
11
- def __init__(self, repo_id):
12
- # Retrieve the Hugging Face token from the environment variable
13
  hf_token = os.getenv("HF_TOKEN")
14
- if hf_token is None:
15
- raise ValueError("HF_TOKEN environment variable is not set.")
16
 
17
  # Load banned keywords list
18
  self.banned_keywords = self.load_banned_keywords()
19
  print(f"Loaded {len(self.banned_keywords)} banned keywords")
20
-
21
- # Download files from Hugging Face Hub using the token
22
- self.onnx_path = hf_hub_download(repo_id=repo_id, filename="model.onnx", use_auth_token=hf_token)
23
- self.tokenizer_path = hf_hub_download(repo_id=repo_id, filename="train_bpe_tokenizer.json", use_auth_token=hf_token)
24
- self.config_path = hf_hub_download(repo_id=repo_id, filename="hyperparameters.json", use_auth_token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # Load configuration
27
- with open(self.config_path) as f:
28
  self.config = json.load(f)
29
 
30
  # Initialize tokenizer
31
  self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
32
- self.max_len = 256
33
 
34
  # Initialize ONNX runtime session
35
- self.session = ort.InferenceSession(self.onnx_path)
36
- self.providers = ['CPUExecutionProvider'] # Use CUDA if available
37
- if 'CUDAExecutionProvider' in ort.get_available_providers():
38
- self.providers = ['CUDAExecutionProvider']
39
- self.session.set_providers(self.providers)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def load_banned_keywords(self):
42
- # For testing purposes, using a small list
43
- # In production, load your full list
44
-
 
 
 
45
  code_str = os.getenv("banned")
46
  if not code_str:
47
- raise Exception("Environment variable 'banned' is not set. Please set it with your banned keywords list.")
 
48
 
49
- # Create a local namespace to execute the code
 
 
 
 
 
 
 
 
50
  local_vars = {}
51
-
52
- # Wrap the code in a function to allow return statements
53
  wrapped_code = f"""
54
  def get_banned_keywords():
55
  {textwrap.indent(code_str, ' ')}
56
  """
57
-
58
  try:
59
- # Execute the wrapped code
60
- exec(wrapped_code, globals(), local_vars)
61
- # Call the function to get the banned keywords
62
- return local_vars['get_banned_keywords']()
 
 
63
  except Exception as e:
64
- print(f"Error loading banned keywords: {e}")
65
- # Return a default empty list if there's an error
66
  return []
67
 
68
  def contains_banned_keyword(self, text):
69
- """Check if the input text contains any banned keywords."""
70
  text_lower = text.lower()
71
-
72
- # Split the text into words
73
- words = ''.join(c if c.isalnum() else ' ' for c in text_lower).split()
74
-
75
  for keyword in self.banned_keywords:
76
- keyword_lower = keyword.lower()
77
-
78
- # Check if keyword is a whole word in the text
79
- if keyword_lower in words:
80
  print(f"Keyword detected: '{keyword}'")
81
  return True
82
-
83
- print("Keywords Passed - No inappropriate keywords found")
84
  return False
85
 
86
  def preprocess(self, text):
87
  encoding = self.tokenizer.encode(text)
88
- ids = encoding.ids[:self.max_len]
89
  padding = [0] * (self.max_len - len(ids))
90
  return np.array(ids + padding, dtype=np.int64).reshape(1, -1)
91
 
 
 
 
 
 
 
 
92
  def predict(self, text):
93
- print(f"\nProcessing input: '{text[:50]}...' ({len(text)} characters)")
94
-
95
- # First check if the text contains any banned keywords
 
96
  if self.contains_banned_keyword(text):
97
  print("Input rejected by keyword filter")
98
  return {
99
- 'label': 'Inappropriate Content',
100
- 'probabilities': [1.0, 0.0] # Assuming [inappropriate, appropriate]
101
  }
102
-
103
- # If no banned keywords found, proceed with model prediction
104
- print("Running ML model for classification...")
105
-
106
  # Preprocess
107
  input_array = self.preprocess(text)
108
 
109
- # Run inference
110
- results = self.session.run(
111
- None,
112
- {'input': input_array}
113
- )
114
 
115
- # Post-process
116
- logits = results[0]
117
- probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
118
- predicted_class = int(np.argmax(probabilities))
119
-
120
- # Map to labels
121
- class_labels = ['Inappropriate Content', 'Appropriate']
122
-
123
- # Log model result
124
- print(f"Model Passed - Result: {class_labels[predicted_class]} (Confidence: {probabilities[0][predicted_class]:.2%})")
125
-
126
- return {
127
- 'label': class_labels[predicted_class],
128
- 'probabilities': probabilities[0].tolist()
129
- }
130
-
131
- # Example usage
132
- if __name__ == "__main__":
133
- # Initialize the pipeline with the Hugging Face repository ID
134
- print("Initializing content filter pipeline...")
135
- pipeline = ONNXInferencePipeline(repo_id="iimran/abuse-detector")
136
- print("Pipeline initialized successfully")
137
-
138
- # Example texts for testing
139
- example_texts = [
140
- "You're a worthless piece of garbage who should die",
141
- "Hello HR, I hope this message finds you well. I'm writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team."
142
- ]
143
-
144
- for text in example_texts:
145
- result = pipeline.predict(text)
146
- print(f"Input: {text[:50]}...")
147
- print(f"Prediction: {result['label']} ")
148
- print("-" * 80)
149
-
150
- # Define a function for Gradio to use
151
- def gradio_predict(text):
152
- result = pipeline.predict(text)
153
- return (
154
- f"Prediction: {result['label']} \n"
155
- )
156
 
157
- # Create a Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  iface = gr.Interface(
159
  fn=gradio_predict,
160
  inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
161
  outputs="text",
162
  title="Abuse Detector - Offensive Language Detector",
163
  description=(
164
- "Abuse detector is designed to identify inappropriate content in text. "
165
- "It analyzes input for Australian Slang language and abuses. "
166
- "While it's trained on a compact dataset and may not catch highly nuanced or sophisticated language, "
167
- "it effectively detects day-to-day offensive language commonly used in conversations."
168
- ),
169
  examples=[
170
- # Explicitly offensive examples
171
- "Congrats, you fuckbrain arsehole, you've outdone yourself in stupidity. A real cock-up of a human—should we clap for your bollocks-faced greatness or just pity you?",
172
- "You're a mad bastard, but I'd still grab a beer with you! Fuck around all you like, you cockheaded legend—your arsehole antics are bloody brilliant.",
173
- "Your mother should have done better raising such a useless idiot.",
174
-
175
- # Neutral or appropriate examples
176
- "Hello HR, I hope this message finds you well. I'm writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
177
- "Thank you for your time and consideration. Please don't hesitate to reach out if you need additional information—I'd be happy to discuss further. Looking forward to hearing from you soon!",
178
- "The weather today is lovely, and I'm looking forward to a productive day at work.",
179
-
180
- # Mixed examples (some offensive, some neutral)
181
- "I appreciate your help, but honestly, you're such a clueless idiot sometimes. Still, thanks for trying."
182
- ]
183
  )
184
-
185
- # Launch the Gradio app
186
  print("Launching Gradio interface...")
187
- # at the very bottom where you launch Gradio
188
  iface.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
 
7
  from huggingface_hub import hf_hub_download
8
  import gradio as gr
9
 
10
+
11
  class ONNXInferencePipeline:
12
+ def __init__(self, repo_id, repo_type="model"):
13
+ # Read token from env. In a Space, HF_TOKEN can be set in the Secrets panel.
14
  hf_token = os.getenv("HF_TOKEN")
 
 
15
 
16
  # Load banned keywords list
17
  self.banned_keywords = self.load_banned_keywords()
18
  print(f"Loaded {len(self.banned_keywords)} banned keywords")
19
+
20
+ # Download artifacts. Newer huggingface_hub uses token=, not use_auth_token=
21
+ self.onnx_path = hf_hub_download(
22
+ repo_id=repo_id,
23
+ filename="model.onnx",
24
+ token=hf_token,
25
+ repo_type=repo_type
26
+ )
27
+ self.tokenizer_path = hf_hub_download(
28
+ repo_id=repo_id,
29
+ filename="train_bpe_tokenizer.json",
30
+ token=hf_token,
31
+ repo_type=repo_type
32
+ )
33
+ self.config_path = hf_hub_download(
34
+ repo_id=repo_id,
35
+ filename="hyperparameters.json",
36
+ token=hf_token,
37
+ repo_type=repo_type
38
+ )
39
 
40
  # Load configuration
41
+ with open(self.config_path, "r") as f:
42
  self.config = json.load(f)
43
 
44
  # Initialize tokenizer
45
  self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
46
+ self.max_len = int(self.config.get("max_len", 256))
47
 
48
  # Initialize ONNX runtime session
49
+ # Spaces CPU runtime typically uses CPUExecutionProvider
50
+ providers = ort.get_available_providers()
51
+ if "CUDAExecutionProvider" in providers:
52
+ use_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
53
+ else:
54
+ use_providers = ["CPUExecutionProvider"]
55
+
56
+ sess_options = ort.SessionOptions()
57
+ # Reduce memory and improve cold start a bit
58
+ sess_options.enable_mem_pattern = False
59
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
60
+
61
+ self.session = ort.InferenceSession(self.onnx_path, sess_options=sess_options, providers=use_providers)
62
+
63
+ # Cache model input name to avoid mismatches like input vs input_ids
64
+ self.input_name = self.session.get_inputs()[0].name
65
+ print(f"ONNX model input name detected: {self.input_name}")
66
+
67
+ # If you want label order from config, you can read it
68
+ self.class_labels = self.config.get("class_labels", ["Inappropriate Content", "Appropriate"])
69
 
70
  def load_banned_keywords(self):
71
+ """
72
+ Load banned keywords from env var named 'banned'.
73
+ Supports two formats:
74
+ 1) Python code snippet that returns a list (your current method)
75
+ 2) JSON array of strings
76
+ """
77
  code_str = os.getenv("banned")
78
  if not code_str:
79
+ print("Environment variable 'banned' is not set. Using empty list.")
80
+ return []
81
 
82
+ # Try JSON first
83
+ try:
84
+ parsed = json.loads(code_str)
85
+ if isinstance(parsed, list) and all(isinstance(x, str) for x in parsed):
86
+ return parsed
87
+ except Exception:
88
+ pass
89
+
90
+ # Fallback to executable code that returns a list
91
  local_vars = {}
 
 
92
  wrapped_code = f"""
93
  def get_banned_keywords():
94
  {textwrap.indent(code_str, ' ')}
95
  """
 
96
  try:
97
+ exec(wrapped_code, {}, local_vars)
98
+ result = local_vars["get_banned_keywords"]()
99
+ if isinstance(result, list):
100
+ return [str(x) for x in result]
101
+ print("Loaded banned keywords code did not return a list. Using empty list.")
102
+ return []
103
  except Exception as e:
104
+ print(f"Error loading banned keywords from code: {e}")
 
105
  return []
106
 
107
  def contains_banned_keyword(self, text):
108
+ """Check if the input text contains any banned keywords as whole words."""
109
  text_lower = text.lower()
110
+ words = "".join(c if c.isalnum() else " " for c in text_lower).split()
111
+ word_set = set(words)
112
+
 
113
  for keyword in self.banned_keywords:
114
+ kw = str(keyword).lower().strip()
115
+ if not kw:
116
+ continue
117
+ if kw in word_set:
118
  print(f"Keyword detected: '{keyword}'")
119
  return True
120
+ print("Keywords Passed. No inappropriate keywords found")
 
121
  return False
122
 
123
  def preprocess(self, text):
124
  encoding = self.tokenizer.encode(text)
125
+ ids = encoding.ids[: self.max_len]
126
  padding = [0] * (self.max_len - len(ids))
127
  return np.array(ids + padding, dtype=np.int64).reshape(1, -1)
128
 
129
+ @staticmethod
130
+ def softmax(logits):
131
+ # Numerically stable softmax
132
+ x = logits - np.max(logits, axis=1, keepdims=True)
133
+ e = np.exp(x)
134
+ return e / np.sum(e, axis=1, keepdims=True)
135
+
136
  def predict(self, text):
137
+ snippet = text[:50].replace("\n", " ")
138
+ print(f"\nProcessing input: '{snippet}...' ({len(text)} characters)")
139
+
140
+ # First rule based filter
141
  if self.contains_banned_keyword(text):
142
  print("Input rejected by keyword filter")
143
  return {
144
+ "label": self.class_labels[0],
145
+ "probabilities": [1.0, 0.0] if len(self.class_labels) == 2 else [1.0] * len(self.class_labels),
146
  }
147
+
 
 
 
148
  # Preprocess
149
  input_array = self.preprocess(text)
150
 
151
+ # Run inference. Use detected input name
152
+ outputs = self.session.run(None, {self.input_name: input_array})
 
 
 
153
 
154
+ # Post process
155
+ logits = outputs[0]
156
+ probs = self.softmax(logits)
157
+ pred_idx = int(np.argmax(probs))
158
+ label = self.class_labels[pred_idx] if pred_idx < len(self.class_labels) else str(pred_idx)
159
+
160
+ print(f"Model Passed. Result: {label} (Confidence: {probs[0][pred_idx]:.2%})")
161
+ return {"label": label, "probabilities": probs[0].tolist()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
+
164
+ # Gradio glue
165
+ def gradio_predict(text):
166
+ result = PIPELINE.predict(text)
167
+ return f"Prediction: {result['label']}\n"
168
+
169
+
170
+ # Create pipeline at import so the Space is ready
171
+ print("Initializing content filter pipeline...")
172
+ PIPELINE = ONNXInferencePipeline(repo_id="iimran/abuse-detector", repo_type="model")
173
+ print("Pipeline initialized successfully")
174
+
175
+
176
+ if __name__ == "__main__":
177
+ # Required in Spaces. PORT is injected. Bind to 0.0.0.0
178
  iface = gr.Interface(
179
  fn=gradio_predict,
180
  inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
181
  outputs="text",
182
  title="Abuse Detector - Offensive Language Detector",
183
  description=(
184
+ "Abuse detector identifies inappropriate content in text. "
185
+ "It analyzes input for Australian slang and abusive language. "
186
+ "It is trained on a compact dataset. It may not catch highly nuanced language, "
187
+ "but it detects common day to day offensive language."
188
+ ),
189
  examples=[
190
+ # Explicitly offensive examples
191
+ "Congrats, you fuckbrain arsehole, you have outdone yourself in stupidity. A real cock up of a human. Should we clap for your bollocks faced greatness or just pity you?",
192
+ "You are a mad bastard, but I would still grab a beer with you. Mess around all you like, you cockheaded legend. Your arsehole antics are bloody brilliant.",
193
+ "Your mother should have done better raising such a useless idiot.",
194
+ # Neutral or appropriate examples
195
+ "Hello HR, I hope this message finds you well. I am writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
196
+ "Thank you for your time and consideration. Please reach out if you need anything. I would be happy to discuss further.",
197
+ "The weather today is lovely, and I am looking forward to a productive day at work.",
198
+ # Mixed
199
+ "I appreciate your help, but honestly, you are such a clueless idiot sometimes. Still, thanks for trying."
200
+ ],
 
 
201
  )
 
 
202
  print("Launching Gradio interface...")
 
203
  iface.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))