Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -391,12 +391,11 @@
|
|
| 391 |
|
| 392 |
# if __name__ == "__main__":
|
| 393 |
# main()
|
| 394 |
-
|
| 395 |
import streamlit as st
|
| 396 |
import matplotlib.pyplot as plt
|
| 397 |
import torch
|
| 398 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
|
| 399 |
-
from transformers import
|
| 400 |
from datasets import load_dataset, Dataset
|
| 401 |
from evaluate import load as load_metric
|
| 402 |
from torch.utils.data import DataLoader
|
|
@@ -430,7 +429,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
|
|
| 430 |
del raw_datasets["unsupervised"]
|
| 431 |
|
| 432 |
if model_name == "google/byt5-small":
|
| 433 |
-
tokenizer =
|
| 434 |
|
| 435 |
def utf8_encode_function(examples):
|
| 436 |
encoded_texts = [list(text.encode('utf-8')) for text in examples["text"]]
|
|
@@ -685,7 +684,7 @@ def main():
|
|
| 685 |
testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
|
| 686 |
|
| 687 |
if model_name == "google/byt5-small":
|
| 688 |
-
net =
|
| 689 |
else:
|
| 690 |
net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
| 691 |
|
|
@@ -790,5 +789,3 @@ def main():
|
|
| 790 |
|
| 791 |
if __name__ == "__main__":
|
| 792 |
main()
|
| 793 |
-
|
| 794 |
-
|
|
|
|
| 391 |
|
| 392 |
# if __name__ == "__main__":
|
| 393 |
# main()
|
|
|
|
| 394 |
import streamlit as st
|
| 395 |
import matplotlib.pyplot as plt
|
| 396 |
import torch
|
| 397 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
|
| 398 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
| 399 |
from datasets import load_dataset, Dataset
|
| 400 |
from evaluate import load as load_metric
|
| 401 |
from torch.utils.data import DataLoader
|
|
|
|
| 429 |
del raw_datasets["unsupervised"]
|
| 430 |
|
| 431 |
if model_name == "google/byt5-small":
|
| 432 |
+
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
| 433 |
|
| 434 |
def utf8_encode_function(examples):
|
| 435 |
encoded_texts = [list(text.encode('utf-8')) for text in examples["text"]]
|
|
|
|
| 684 |
testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
|
| 685 |
|
| 686 |
if model_name == "google/byt5-small":
|
| 687 |
+
net = T5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
|
| 688 |
else:
|
| 689 |
net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
| 690 |
|
|
|
|
| 789 |
|
| 790 |
if __name__ == "__main__":
|
| 791 |
main()
|
|
|
|
|
|