File size: 2,841 Bytes
d4a7de9
 
 
 
142339c
d4a7de9
e65359a
d4a7de9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
614cc6d
 
 
d4a7de9
 
 
 
 
 
 
 
 
 
 
 
b80570e
 
 
 
 
 
 
 
 
 
 
d4a7de9
 
 
 
 
 
 
 
 
 
 
 
 
b80570e
 
d4a7de9
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from typing import Dict, Any
import base64
import tempfile
import os
os.environ["TRANSFORMERS_NO_FLASH_ATTN_2"] = "1"
import sys


# Make sure the videollama2 package can be imported (the model code must be in this directory or already installed).
sys.path.append('./')

from videollama2 import model_init, mm_infer
from videollama2.utils import disable_torch_init

class EndpointHandler:
    """Inference endpoint that answers a prompt about one video or image.

    The model, the per-modality processors and the tokenizer are loaded once
    in ``__init__``; every call to the instance handles a single request.
    """

    # Prompt used when the request carries none (or an empty one).
    _DEFAULT_PROMPT = "Describe the content."
    # Suffix given to the temporary file written for each modality.
    _SUFFIXES = {"video": ".mp4", "image": ".png"}
    # Returned whenever the request has no usable media field.
    _MISSING_MEDIA_ERROR = "请求必须包含 'video' 或 'image' 字段"

    def __init__(self, path=""):
        """Load the model, processor and tokenizer.

        Args:
            path: Model location handed in by the hosting environment. When
                empty, the official ``DAMO-NLP-SG/VideoLLaMA2-7B-16F``
                repository is used instead.
        """
        # Called before loading, as in the upstream videollama2 examples
        # (presumably skips redundant default weight initialisation — see
        # videollama2.utils to confirm).
        disable_torch_init()
        self.model_path = path or "DAMO-NLP-SG/VideoLLaMA2-7B-16F"
        self.model, self.processor, self.tokenizer = model_init(self.model_path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference for one request.

        Expected payload (optionally wrapped as ``{"inputs": {...}}``)::

            {"video": "<base64 string or file path>", "prompt": "..."}

        or::

            {"image": "<base64 string or file path>", "prompt": "..."}

        When both keys are present, ``"video"`` wins.

        Args:
            data: Request body as a dict.

        Returns:
            ``{"modal": ..., "prompt": ..., "result": ...}`` on success, or
            ``{"error": "<message>"}`` when the request is malformed
            (missing media field, non-string media, undecodable base64).

        Raises:
            Whatever the processor or ``mm_infer`` raises; inference errors
            are deliberately not swallowed.
        """
        # Hugging Face endpoints put the real payload under "inputs".
        data = data.get("inputs", data)

        # A bare string/list under "inputs" cannot carry a media field; without
        # this guard `"video" in data` would be a substring test on a str.
        if not isinstance(data, dict):
            return {"error": self._MISSING_MEDIA_ERROR}

        modal = next((m for m in ("video", "image") if m in data), None)
        if modal is None:
            return {"error": self._MISSING_MEDIA_ERROR}

        # `or` also covers an explicit None / empty prompt.
        prompt = data.get("prompt") or self._DEFAULT_PROMPT

        try:
            media_path, cleanup = self._materialize(data[modal], modal)
        except ValueError as exc:
            return {"error": str(exc)}

        try:
            inputs = self.processor[modal](media_path)
            output = mm_infer(
                inputs,
                prompt,
                model=self.model,
                tokenizer=self.tokenizer,
                do_sample=False,
                modal=modal,
            )
        finally:
            if cleanup:
                # Best-effort removal: a failure here must not mask the
                # inference result or the exception already in flight.
                try:
                    os.remove(media_path)
                except OSError:
                    pass

        # Uniform response shape so callers can parse it easily.
        return {
            "modal": modal,
            "prompt": prompt,
            "result": output,
        }

    @staticmethod
    def _materialize(payload, modal):
        """Turn a request media value into a path on disk.

        Args:
            payload: Either an existing filesystem path or base64 text
                (optionally prefixed with a ``data:...;base64,`` header).
            modal: ``"video"`` or ``"image"``; selects the temp-file suffix.

        Returns:
            ``(path, cleanup)`` where ``cleanup`` is True when ``path`` is a
            temporary file the caller must delete.

        Raises:
            ValueError: If ``payload`` is not a non-empty string or cannot be
                decoded as base64 into a non-empty byte string.
        """
        if not isinstance(payload, str) or not payload.strip():
            raise ValueError(
                f"'{modal}' 字段必须是非空字符串(base64 编码或文件路径)"
            )

        # NOTE(security): an existing path is used as-is, which lets a caller
        # point the model at any file readable by this process. Kept for
        # backward compatibility with the web-upload flow; restrict it to an
        # upload directory if the endpoint is exposed to untrusted clients.
        if os.path.exists(payload):
            return payload, False

        # Accept data URIs such as "data:video/mp4;base64,AAAA...".
        if payload.startswith("data:"):
            header, sep, body = payload.partition(",")
            if sep and header.endswith(";base64"):
                payload = body

        # Decode BEFORE creating the temp file so that a malformed payload
        # cannot leave an orphaned file behind (delete=False below).
        try:
            raw = base64.b64decode(payload)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise ValueError(
                f"'{modal}' 字段不是有效的 base64 数据: {exc}"
            ) from exc
        if not raw:
            raise ValueError(f"'{modal}' 字段解码后内容为空")

        suffix = EndpointHandler._SUFFIXES[modal]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(raw)
        return tmp_file.name, True