Spaces: Running on CPU Upgrade
Commit · 44af154
1 Parent(s): 8ec42ec
Bayes threshold added
Browse files
- app.py +2 -2
- models/bayes.py +20 -5
app.py
CHANGED
|
@@ -66,7 +66,7 @@ def classify_text(classifier, text, threshold, output_format):
|
|
| 66 |
elif classifier == 'LSTM':
|
| 67 |
tags = get_tags_lstm(lstm, text, threshold)
|
| 68 |
else:
|
| 69 |
-
tags = get_tags_bayes(bayes, text)
|
| 70 |
|
| 71 |
tags = expand(tags)
|
| 72 |
predictions = {tag: topics_full[tag] for tag in tags if tag in topics_full}
|
|
@@ -100,7 +100,7 @@ with gr.Blocks(theme="default") as demo:
|
|
| 100 |
label="Input Text",
|
| 101 |
)
|
| 102 |
threshold = gr.Slider(0, 1, value=0.5, step=0.05,
|
| 103 |
-
label="Confidence Threshold
|
| 104 |
output_format = gr.Radio(
|
| 105 |
["Markdown", "JSON"],
|
| 106 |
value="Markdown",
|
|
|
|
| 66 |
elif classifier == 'LSTM':
|
| 67 |
tags = get_tags_lstm(lstm, text, threshold)
|
| 68 |
else:
|
| 69 |
+
tags = get_tags_bayes(bayes, text, threshold)
|
| 70 |
|
| 71 |
tags = expand(tags)
|
| 72 |
predictions = {tag: topics_full[tag] for tag in tags if tag in topics_full}
|
|
|
|
| 100 |
label="Input Text",
|
| 101 |
)
|
| 102 |
threshold = gr.Slider(0, 1, value=0.5, step=0.05,
|
| 103 |
+
label="Confidence Threshold")
|
| 104 |
output_format = gr.Radio(
|
| 105 |
["Markdown", "JSON"],
|
| 106 |
value="Markdown",
|
models/bayes.py
CHANGED
|
@@ -15,9 +15,9 @@ class NaiveBayesMultiClass(object):
|
|
| 15 |
self.classifiers = []
|
| 16 |
|
| 17 |
def load(self, path: str):
|
| 18 |
-
self.vectorizer = joblib.load(f
|
| 19 |
self.classifiers = [
|
| 20 |
-
joblib.load(f
|
| 21 |
]
|
| 22 |
|
| 23 |
def predict(self, X: typing.Iterable[str] | str, get_tags=False):
|
|
@@ -27,7 +27,7 @@ class NaiveBayesMultiClass(object):
|
|
| 27 |
by_class = [self.classifiers[i].predict(x) for i in range(self.n_classes)]
|
| 28 |
ans = []
|
| 29 |
|
| 30 |
-
for i in range(len(X)):
|
| 31 |
y = []
|
| 32 |
for j, cls in enumerate(self.classes):
|
| 33 |
if get_tags:
|
|
@@ -41,6 +41,21 @@ class NaiveBayesMultiClass(object):
|
|
| 41 |
def __call__(self, *args, **kwargs):
|
| 42 |
return self.predict(*args, **kwargs)
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
def get_tags_bayes(model, text):
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
self.classifiers = []
|
| 16 |
|
| 17 |
def load(self, path: str):
|
| 18 |
+
self.vectorizer = joblib.load(f"{path}/vectorizer.joblib")
|
| 19 |
self.classifiers = [
|
| 20 |
+
joblib.load(f"{path}/class_{i}.joblib") for i in range(self.n_classes)
|
| 21 |
]
|
| 22 |
|
| 23 |
def predict(self, X: typing.Iterable[str] | str, get_tags=False):
|
|
|
|
| 27 |
by_class = [self.classifiers[i].predict(x) for i in range(self.n_classes)]
|
| 28 |
ans = []
|
| 29 |
|
| 30 |
+
for i in range(len(X)): # type: ignore
|
| 31 |
y = []
|
| 32 |
for j, cls in enumerate(self.classes):
|
| 33 |
if get_tags:
|
|
|
|
| 41 |
def __call__(self, *args, **kwargs):
|
| 42 |
return self.predict(*args, **kwargs)
|
| 43 |
|
| 44 |
+
def predict_proba(self, X: typing.Iterable[str] | str):
|
| 45 |
+
if type(X) == str:
|
| 46 |
+
return self.predict_proba([X])[0]
|
| 47 |
+
x = self.vectorizer.transform(X)
|
| 48 |
+
by_class = [self.classifiers[i].predict_proba(x) for i in range(self.n_classes)]
|
| 49 |
+
|
| 50 |
+
return [[by_class[j][i] for j in range(self.n_classes)] for i in range(len(X))] # type: ignore
|
| 51 |
+
|
| 52 |
|
| 53 |
+
def get_tags_bayes(model: NaiveBayesMultiClass, text: str, threshold : None | float =None):
|
| 54 |
+
if threshold is None:
|
| 55 |
+
return model.predict(text, True)
|
| 56 |
+
probs = model.predict_proba(text)
|
| 57 |
+
present = []
|
| 58 |
+
for i, cls in enumerate(model.classes):
|
| 59 |
+
if probs[i] >= threshold: # type: ignore
|
| 60 |
+
present.append(cls)
|
| 61 |
+
return present
|