Lxz20071231 commited on
Commit
44af154
·
1 Parent(s): 8ec42ec

Bayes threshold added

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. models/bayes.py +20 -5
app.py CHANGED
@@ -66,7 +66,7 @@ def classify_text(classifier, text, threshold, output_format):
66
  elif classifier == 'LSTM':
67
  tags = get_tags_lstm(lstm, text, threshold)
68
  else:
69
- tags = get_tags_bayes(bayes, text)
70
 
71
  tags = expand(tags)
72
  predictions = {tag: topics_full[tag] for tag in tags if tag in topics_full}
@@ -100,7 +100,7 @@ with gr.Blocks(theme="default") as demo:
100
  label="Input Text",
101
  )
102
  threshold = gr.Slider(0, 1, value=0.5, step=0.05,
103
- label="Confidence Threshold (not available for Naïve Bayes)")
104
  output_format = gr.Radio(
105
  ["Markdown", "JSON"],
106
  value="Markdown",
 
66
  elif classifier == 'LSTM':
67
  tags = get_tags_lstm(lstm, text, threshold)
68
  else:
69
+ tags = get_tags_bayes(bayes, text, threshold)
70
 
71
  tags = expand(tags)
72
  predictions = {tag: topics_full[tag] for tag in tags if tag in topics_full}
 
100
  label="Input Text",
101
  )
102
  threshold = gr.Slider(0, 1, value=0.5, step=0.05,
103
+ label="Confidence Threshold")
104
  output_format = gr.Radio(
105
  ["Markdown", "JSON"],
106
  value="Markdown",
models/bayes.py CHANGED
@@ -15,9 +15,9 @@ class NaiveBayesMultiClass(object):
15
  self.classifiers = []
16
 
17
  def load(self, path: str):
18
- self.vectorizer = joblib.load(f'{path}/vectorizer.joblib')
19
  self.classifiers = [
20
- joblib.load(f'{path}/class_{i}.joblib') for i in range(self.n_classes)
21
  ]
22
 
23
  def predict(self, X: typing.Iterable[str] | str, get_tags=False):
@@ -27,7 +27,7 @@ class NaiveBayesMultiClass(object):
27
  by_class = [self.classifiers[i].predict(x) for i in range(self.n_classes)]
28
  ans = []
29
 
30
- for i in range(len(X)):
31
  y = []
32
  for j, cls in enumerate(self.classes):
33
  if get_tags:
@@ -41,6 +41,21 @@ class NaiveBayesMultiClass(object):
41
  def __call__(self, *args, **kwargs):
42
  return self.predict(*args, **kwargs)
43
 
 
 
 
 
 
 
 
 
44
 
45
- def get_tags_bayes(model, text):
46
- return model.predict(clean(text), True)
 
 
 
 
 
 
 
 
15
  self.classifiers = []
16
 
17
  def load(self, path: str):
18
+ self.vectorizer = joblib.load(f"{path}/vectorizer.joblib")
19
  self.classifiers = [
20
+ joblib.load(f"{path}/class_{i}.joblib") for i in range(self.n_classes)
21
  ]
22
 
23
  def predict(self, X: typing.Iterable[str] | str, get_tags=False):
 
27
  by_class = [self.classifiers[i].predict(x) for i in range(self.n_classes)]
28
  ans = []
29
 
30
+ for i in range(len(X)): # type: ignore
31
  y = []
32
  for j, cls in enumerate(self.classes):
33
  if get_tags:
 
41
  def __call__(self, *args, **kwargs):
42
  return self.predict(*args, **kwargs)
43
 
44
+ def predict_proba(self, X: typing.Iterable[str] | str):
45
+ if type(X) == str:
46
+ return self.predict_proba([X])[0]
47
+ x = self.vectorizer.transform(X)
48
+ by_class = [self.classifiers[i].predict_proba(x) for i in range(self.n_classes)]
49
+
50
+ return [[by_class[j][i] for j in range(self.n_classes)] for i in range(len(X))] # type: ignore
51
+
52
 
53
+ def get_tags_bayes(model: NaiveBayesMultiClass, text: str, threshold : None | float =None):
54
+ if threshold is None:
55
+ return model.predict(text, True)
56
+ probs = model.predict_proba(text)
57
+ present = []
58
+ for i, cls in enumerate(model.classes):
59
+ if probs[i] >= threshold: # type: ignore
60
+ present.append(cls)
61
+ return present