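"""Command-line multi-label emotion classifier for (Polish) tweets.

Loads a sequence-classification model (by default
yazoniak/twitter-emotion-pl-classifier from the Hugging Face Hub), applies a
sigmoid to each label's logit, and reports every label whose probability
exceeds the threshold.

Usage (the script filename here is illustrative):
    python classify_emotions.py "Ale piękny dzień! @jan_kowalski" --threshold 0.6
"""
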
import argparse
import re

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

DEFAULT_THRESHOLD = 0.5

def preprocess_text(text, anonymize_mentions=True):
    """Mask @mentions with a placeholder so user handles don't influence the prediction.

    >>> preprocess_text("cc @jan_kowalski!")
    'cc @anonymized_account!'
    """
    if anonymize_mentions:
        text = re.sub(r'@\w+', '@anonymized_account', text)
    return text

def main():
    parser = argparse.ArgumentParser(description="Multi-label emotion classification for (Polish) tweets")
    parser.add_argument("text", type=str, help="Text to classify")
    parser.add_argument("--model-path", type=str, default="yazoniak/twitter-emotion-pl-classifier",
                        help="Path to a local model or a Hugging Face model ID")
    parser.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD,
                        help="Classification threshold (default: %(default)s)")
    parser.add_argument("--no-anonymize", action="store_true",
                        help="Disable mention anonymization (not recommended)")
    args = parser.parse_args()

    print(f"Loading model from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_path)
    model.eval()
    
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

    anonymize = not args.no_anonymize
    processed_text = preprocess_text(args.text, anonymize_mentions=anonymize)

    print(f"\nInput text: {args.text}")
    if anonymize and processed_text != args.text:
        print(f"Preprocessed text: {processed_text}")
    print()

    # Tokenize the (possibly anonymized) text, truncating overly long inputs.
    inputs = tokenizer(processed_text, return_tensors="pt", truncation=True, max_length=8192)
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits.squeeze(0).numpy()

    # Multi-label classification: apply an independent sigmoid to each label's
    # logit and assign every label whose probability exceeds the threshold.
    probabilities = 1 / (1 + np.exp(-logits))
    predictions = probabilities > args.threshold

    assigned_labels = [label for label, assigned in zip(labels, predictions) if assigned]
    
    if assigned_labels:
        print("Assigned Labels:")
        print("-" * 40)
        for label in assigned_labels:
            print(f"  {label}")
        print()
    else:
        print("No labels assigned (all below threshold)\n")
    
    print("All Labels (with probabilities):")
    print("-" * 40)
    for i, label in enumerate(labels):
        status = "✓" if predictions[i] else " "
        print(f"{status} {label:15s}: {probabilities[i]:.4f}")

if __name__ == "__main__":
    main()
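
# A minimal sketch of scoring the same model programmatically instead of via
# this CLI, assuming the transformers pipeline API (top_k=None returns the
# score for every label rather than only the best one):
#
#     from transformers import pipeline
#
#     classifier = pipeline("text-classification",
#                           model="yazoniak/twitter-emotion-pl-classifier",
#                           top_k=None)
#     scores = classifier(preprocess_text("Ale piękny dzień! @jan_kowalski"))
#     print(scores)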