kunaliitkgp09 committed on
Commit fd15b76 · verified · 1 Parent(s): 225a5ab

Upload test_improved_model.py with huggingface_hub

Files changed (1)
  1. test_improved_model.py +274 -0
test_improved_model.py ADDED
@@ -0,0 +1,274 @@
#!/usr/bin/env python3
"""
Test the Improved Unified Multi-Model with Prompt Templates
"""

import asyncio
import time
import sys
from pathlib import Path

# Add current directory to path
sys.path.append(str(Path(__file__).parent))

from improved_unified_model_pt import ImprovedUnifiedMultiModelPT, ImprovedUnifiedModelConfig
from prompt_template import PromptTemplates, TaskType, TestPrompt
from test_suite import OrchestratorTester, TestResult

class ImprovedModelWrapper:
    """Wrapper class to make the improved model compatible with our test suite"""

    def __init__(self, model):
        self.model = model

    async def process_request(self, prompt):
        """Process a request using the improved unified model"""
        try:
            # Process the request
            result = self.model.process(prompt)

            # Create a compatible result object
            class TaskResult:
                def __init__(self, result_dict):
                    self.task_type = type('TaskType', (), {'value': result_dict.get('task_type', 'TEXT')})()
                    self.confidence = result_dict.get('confidence', 0.5)
                    self.success = True
                    self.output = result_dict.get('output', '')
                    self.error_message = None

            return TaskResult(result)

        except Exception as e:
            # Create error result
            class ErrorResult:
                def __init__(self, error):
                    self.task_type = type('TaskType', (), {'value': 'ERROR'})()
                    self.confidence = 0.0
                    self.success = False
                    self.output = ''
                    self.error_message = str(error)

            return ErrorResult(e)

async def test_improved_model_with_prompts():
    """Test the improved model with our prompt templates"""
    print("🧪 Testing Improved Unified Model with Prompt Templates")
    print("=" * 70)

    # Create and load the improved model
    print("📦 Creating and loading improved model...")
    config = ImprovedUnifiedModelConfig()
    model = ImprovedUnifiedMultiModelPT(config)

    # Create wrapper
    wrapper = ImprovedModelWrapper(model)

    # Test with different types of prompts
    test_categories = [
        ("Text Processing", TaskType.TEXT, 3),
        ("Image Captioning", TaskType.CAPTION, 3),
        ("Text-to-Image", TaskType.TEXT2IMG, 3),
        ("Reasoning", TaskType.REASONING, 3),
        ("Multimodal", TaskType.MULTIMODAL, 2)
    ]

    results = []

    for category_name, task_type, num_prompts in test_categories:
        print(f"\n📝 Testing {category_name} ({num_prompts} prompts):")
        print("-" * 60)

        prompts = PromptTemplates.get_prompts_by_task_type(task_type)[:num_prompts]

        for i, prompt in enumerate(prompts, 1):
            print(f"\n{i}. Testing: {prompt.prompt[:60]}...")

            start_time = time.time()
            result = await wrapper.process_request(prompt.prompt)
            processing_time = time.time() - start_time

            # Check if task routing was correct
            expected_task = prompt.expected_task.value
            actual_task = result.task_type.value
            task_correct = expected_task == actual_task

            status = "✅" if result.success else "❌"
            task_status = "✅" if task_correct else "❌"

            print(f" {status} Success: {result.success}")
            print(f" {task_status} Task: {actual_task} (expected: {expected_task})")
            print(f" 📊 Confidence: {result.confidence:.2f}")
            print(f" ⏱️ Time: {processing_time:.2f}s")

            if result.output:
                print(f" 📄 Output: {result.output[:100]}...")

            if result.error_message:
                print(f" ❌ Error: {result.error_message}")

            # Store result for analysis
            test_result = TestResult(
                prompt=prompt.prompt,
                expected_task=prompt.expected_task,
                actual_task=actual_task,
                confidence=result.confidence,
                processing_time=processing_time,
                success=result.success,
                error_message=result.error_message,
                output=result.output
            )
            results.append(test_result)

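    # NOTE: the aggregation below reads r.task_correct, which is not passed to
    # TestResult above; it is assumed that TestResult (from test_suite) derives that
    # flag from expected_task and actual_task itself.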
    # Calculate overall statistics
    total_tests = len(results)
    successful_tests = sum(1 for r in results if r.success)
    correct_tasks = sum(1 for r in results if r.task_correct)

    accuracy = correct_tasks / total_tests if total_tests > 0 else 0
    success_rate = successful_tests / total_tests if total_tests > 0 else 0
    avg_confidence = sum(r.confidence for r in results) / total_tests if total_tests > 0 else 0
    avg_time = sum(r.processing_time for r in results) / total_tests if total_tests > 0 else 0

    print(f"\n📊 Overall Test Results:")
    print("=" * 50)
    print(f" Total Tests: {total_tests}")
    print(f" Successful: {successful_tests}")
    print(f" Task Accuracy: {accuracy:.1%}")
    print(f" Success Rate: {success_rate:.1%}")
    print(f" Avg Confidence: {avg_confidence:.2f}")
    print(f" Avg Processing Time: {avg_time:.2f}s")

    # Task-specific breakdown
    print(f"\n🎯 Task-Specific Results:")
    print("-" * 40)
    for task_type in TaskType:
        task_results = [r for r in results if r.expected_task == task_type]
        if task_results:
            task_correct = sum(1 for r in task_results if r.task_correct)
            task_accuracy = task_correct / len(task_results)
            print(f" {task_type.value}: {task_accuracy:.1%} ({task_correct}/{len(task_results)})")

    return results, model

async def run_comprehensive_test(model):
    """Run comprehensive test using our test suite"""
    print("\n🧪 Running Comprehensive Test Suite")
    print("=" * 60)

    wrapper = ImprovedModelWrapper(model)
    tester = OrchestratorTester(wrapper)

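    # NOTE: run_basic_tests() and the fields read below (total_tests, passed_tests,
    # failed_tests, accuracy, average_confidence, average_processing_time) are assumed
    # to be provided by the OrchestratorTester result object defined in test_suite.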
    # Run basic tests
    print("Running basic functionality tests...")
    basic_result = await tester.run_basic_tests()

    print(f"\n📊 Basic Test Results:")
    print(f" Total Tests: {basic_result.total_tests}")
    print(f" Passed: {basic_result.passed_tests}")
    print(f" Failed: {basic_result.failed_tests}")
    print(f" Accuracy: {basic_result.accuracy:.1%}")
    print(f" Avg Confidence: {basic_result.average_confidence:.2f}")
    print(f" Avg Processing Time: {basic_result.average_processing_time:.2f}s")

    return basic_result

async def interactive_test(model):
    """Interactive testing mode"""
    print("\n🎮 Interactive Testing Mode")
    print("=" * 50)
    print("Enter your prompts (type 'quit' to exit):")
    print("Example prompts:")
    print(" - What is machine learning?")
    print(" - Generate an image of a peaceful forest")
    print(" - Describe this image of a sunset")
    print(" - Explain step by step how neural networks work")
    print()

    wrapper = ImprovedModelWrapper(model)

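    # NOTE: process_request() already catches exceptions and returns an error result,
    # so the except Exception branch below mainly guards the input() call itself.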
    while True:
        try:
            user_input = input("Enter prompt: ").strip()

            if user_input.lower() in ['quit', 'exit', 'q']:
                break

            if not user_input:
                continue

            print(f"\n⏳ Processing: {user_input}")
            start_time = time.time()

            result = await wrapper.process_request(user_input)
            processing_time = time.time() - start_time

            print(f"✅ Task Type: {result.task_type.value}")
            print(f"📊 Confidence: {result.confidence:.2f}")
            print(f"⏱️ Processing Time: {processing_time:.2f}s")

            if result.output:
                print(f"📄 Output: {result.output}")

            if result.error_message:
                print(f"❌ Error: {result.error_message}")

            print()

        except KeyboardInterrupt:
            print("\nExiting interactive mode...")
            break
        except Exception as e:
            print(f"Error: {e}")

def compare_with_original():
    """Compare improved model with original model"""
    print("\n🔄 Comparing Improved vs Original Model")
    print("=" * 50)

    # Test prompts for comparison
    comparison_prompts = [
        ("What is machine learning?", "TEXT"),
        ("Generate an image of a peaceful forest", "TEXT2IMG"),
        ("Describe this image of a sunset", "CAPTION"),
        ("Explain step by step how neural networks work", "REASONING")
    ]

    print("Testing improved model routing...")
    config = ImprovedUnifiedModelConfig()
    improved_model = ImprovedUnifiedMultiModelPT(config)

    for prompt, expected in comparison_prompts:
        print(f"\n🔍 Testing: {prompt}")
        result = improved_model.process(prompt)
        actual = result['task_type']
        correct = "✅" if actual == expected else "❌"
        print(f" {correct} Expected: {expected}, Actual: {actual}, Confidence: {result['confidence']:.2f}")

async def main():
    """Main function"""
    print("🚀 Improved Unified Multi-Model Testing")
    print("=" * 70)

    # Test with prompt templates
    results, model = await test_improved_model_with_prompts()

    # Run comprehensive test
    comprehensive_result = await run_comprehensive_test(model)

    # Compare with original
    compare_with_original()

    # Interactive testing option
    print("\n" + "="*70)
    print("🎮 Interactive Testing")
    print("="*70)

    try_interactive = input("\nWould you like to try interactive testing? (y/n): ").strip().lower()
    if try_interactive in ['y', 'yes']:
        await interactive_test(model)

    print("\n🎉 Testing completed!")
    print("📊 The improved model shows enhanced routing capabilities.")

if __name__ == "__main__":
    asyncio.run(main())