# actio-ui-7b-sft / sanity.py
# Last commit by chc012 (85ca5ad): "Add sanity check script and example screenshot with LFS tracking"
# File size at that commit: 4.2 kB
import base64
import sys
import torch
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoProcessor
from PIL import Image
def encode_image(image_path: str) -> str:
    """Return the raw bytes of *image_path* as a base64 ASCII string.

    The file is read in binary mode, base64-encoded, and decoded to ``str``
    so the result can be embedded directly in a ``data:`` URL.
    """
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
def load_model(
    model_path: str,
) -> tuple[AutoModelForVision2Seq, AutoTokenizer, AutoProcessor]:
    """Load the vision-language model, its tokenizer, and its processor.

    Every loader is called with ``trust_remote_code=True``, so custom code
    shipped with the checkpoint at ``model_path`` is executed. The model is
    loaded with ``torch_dtype="auto"`` and ``device_map="auto"``.

    Returns:
        A ``(model, tokenizer, processor)`` tuple, in that order. The third
        element comes from ``AutoProcessor`` (callers in this file bind it
        to a variable named ``image_processor``).
    """
    remote_kwargs = {"trust_remote_code": True}

    tok = AutoTokenizer.from_pretrained(model_path, **remote_kwargs)
    vlm = AutoModelForVision2Seq.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        **remote_kwargs,
    )
    proc = AutoProcessor.from_pretrained(model_path, **remote_kwargs)

    return vlm, tok, proc
def create_grounding_messages(image_path: str, instruction: str) -> list[dict]:
    """Build the chat message list for one GUI-grounding request.

    The result holds two messages:

    * a ``system`` message describing the GUI-agent / pyautogui role, and
    * a ``user`` message whose content is a text part (a fixed task prefix
      followed by ``instruction``) and then an image part.

    The image part embeds the whole file at ``image_path`` as a base64
    ``data:`` URL. Its MIME type is always ``image/png``, regardless of the
    file's actual format.
    """
    system_text = (
        "You are a GUI agent. You are given a task and a screenshot of the screen. "
        "You need to perform a series of pyautogui actions to complete the task."
    )
    task_text = (
        "Please perform the following task by providing the action and the coordinates: "
        + instruction
    )
    image_url = f"data:image/png;base64,{encode_image(image_path)}"

    user_parts = [
        {"type": "text", "text": task_text},
        {"type": "image", "image": image_url},
    ]
    return [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_parts},
    ]
def run_inference(
    model: AutoModelForVision2Seq,
    tokenizer: AutoTokenizer,
    image_processor: AutoProcessor,
    messages: list[dict],
    image_path: str,
    max_new_tokens: int = 2048,
) -> str:
    """Run greedy generation for one (messages, image) pair and return the text.

    Args:
        model: Loaded model; all processor outputs are moved to ``model.device``.
        tokenizer: Not used by this function (decoding goes through
            ``image_processor.batch_decode``); accepted for call-site
            compatibility.
        image_processor: Processor providing ``apply_chat_template``, the
            combined text+image ``__call__``, and ``batch_decode``.
        messages: Chat messages, e.g. as built by ``create_grounding_messages``.
        image_path: Path of the image whose pixels are given to the processor.
        max_new_tokens: Upper bound on the number of generated tokens.
            Defaults to 2048, the previously hard-coded value.

    Returns:
        The decoded completion for the single batch element. Prompt tokens
        are stripped and special tokens are skipped.
    """
    # Render the chat template to a prompt string only; the processor call
    # below performs the actual tokenization together with the image.
    text = image_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Open the image in a context manager so the file handle is released
    # deterministically. convert("RGB") returns a new, fully loaded image,
    # so it stays usable after the source file is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = image_processor(
        text=[text], images=[image], padding=True, return_tensors="pt"
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Greedy decoding (do_sample=False) keeps the sanity check deterministic.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # generate() returns prompt + completion for each row; drop the prompt.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_text = image_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return output_text
def main():
    """Run the end-to-end sanity check and return a process exit code.

    Workflow:

    1. Verify that the screenshot exists.
    2. Load the model at ``model_path``.
    3. Build a grounding prompt from the default instruction, or from the
       command-line arguments joined with spaces when any are given.
    4. Run inference and print the model output.

    Returns:
        0 on success. 1 if the image file is missing, the model fails to
        load, or inference fails. Diagnostics for failures go to stderr.
    """
    import os
    import traceback

    # Configuration
    model_path = "Uniphore/actio-ui-7b-sft"  # or other model variants
    image_path = "screenshot.png"
    instruction = "Click on the submit button"

    # Any command-line arguments replace the default instruction.
    if len(sys.argv) > 1:
        instruction = " ".join(sys.argv[1:])

    # Fail fast. Model loading is the expensive step, so check the input
    # before it rather than discovering a missing image only at inference time.
    if not os.path.isfile(image_path):
        print(f"✗ Image not found: {image_path}", file=sys.stderr)
        return 1

    print(f"Loading model from: {model_path}")
    try:
        model, tokenizer, image_processor = load_model(model_path)
    except Exception as e:
        print(f"✗ Error loading model: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
    print("✓ Model loaded successfully")

    print(f"Processing image: {image_path}")
    print(f"Instruction: {instruction}")
    try:
        messages = create_grounding_messages(image_path, instruction)
        result = run_inference(model, tokenizer, image_processor, messages, image_path)
    except Exception as e:
        print(f"✗ Error during inference: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1

    print("\n" + "=" * 60)
    print("MODEL OUTPUT:")
    print("=" * 60)
    print(result)
    print("=" * 60)
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status;
    # raising SystemExit is exactly what sys.exit() does internally.
    raise SystemExit(main())