fuuuzzy commited on
Commit
7c71fa7
·
verified ·
1 Parent(s): 6b80f9b

Upload folder using huggingface_hub

Browse files
.gitignore CHANGED
@@ -10,3 +10,8 @@ wheels/
10
  .venv
11
  output
12
  .cache
 
 
 
 
 
 
10
  .venv
11
  output
12
  .cache
13
+ .cursor
14
+ logs
15
+ models
16
+ data
17
+ pt-br
api.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## 配音任务
2
+
3
+ - Request
4
+ - Path: ``
5
+ - Method: `post`
6
+ - Body:
7
+ ```json5
8
+ {
9
+ "character_voice": [
10
+ {
11
+ "character": "女主妈妈",//角色名
12
+ "id": "104982", //参考音频id
13
+ ""
14
+ "timbre_url": "https://xxx",//参考音频的地址
15
+ "timbre_text":""//参考音频文本
16
+ }
17
+ ],//参考角色音频信息
18
+ "content": [
19
+ {
20
+ "character": "女主妈妈",//角色名,跟character_voice对应上
21
+ "end": 0.9, //时间轴结束时间,时间格式的单位为秒 0.9则表示900毫秒,接收了请求处理的时候需要转换成00:00:00,100,标准的srt时间轴格式
22
+ "source": "你好", //原文本
23
+ "start": 14, //时间轴开始时间,时间格式的单位为秒,14则表示14秒,接收了请求处理的时候需要转换成00:00:00,100,标准的srt时间轴格式
24
+ "translation": "Hello" //需要生成语音的文本
25
+ }
26
+ ], //配音内容
27
+ "hook_url": "https://your-api.com/callback", //回调地址
28
+ "priority": 3, //优先级 1-5 最高5
29
+ "video_url": "https://example.com/video.mp4" //视频地址
30
+ }
31
+ ```
api.sh ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Process-control script for the f5-tts-api service (start/stop/restart/status).

# --- Configuration ---
APP_NAME="f5-tts-api"
PID_FILE="app.pid"
LOG_FILE="logs/startup.log"
PYTHON_CMD="uv run app.py"

# Resolve the directory containing this script and cd into it so the
# relative PID_FILE/LOG_FILE paths work no matter where the caller runs from.
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$SCRIPT_DIR"

# Make sure the logs directory exists before redirecting output into it.
mkdir -p logs
15
+
16
start() {
    # Refuse to start twice: a PID file pointing at a live process means
    # the app is already running; a stale PID file is cleaned up.
    if [ -f "$PID_FILE" ]; then
        pid=$(cat "$PID_FILE")
        if ps -p "$pid" > /dev/null; then
            echo "$APP_NAME is already running (PID: $pid)"
            return
        else
            echo "PID file exists but process is gone. Cleaning up."
            rm "$PID_FILE"
        fi
    fi

    echo "Starting $APP_NAME..."
    # Detach from the terminal; stdout+stderr go to the startup log.
    nohup $PYTHON_CMD > "$LOG_FILE" 2>&1 &
    pid=$!
    echo "$pid" > "$PID_FILE"
    echo "$APP_NAME started with PID $pid"
    echo "Logs are being written to $LOG_FILE"
}
35
+
36
stop() {
    if [ ! -f "$PID_FILE" ]; then
        echo "$APP_NAME is not running (PID file not found)"
        return
    fi

    pid=$(cat "$PID_FILE")
    if ps -p "$pid" > /dev/null; then
        echo "Stopping $APP_NAME (PID: $pid)..."
        kill "$pid"
        # Wait for the process to exit gracefully; escalate to SIGKILL
        # if it is still alive after 10 seconds.
        count=0
        while ps -p "$pid" > /dev/null; do
            sleep 1
            count=$((count + 1))
            if [ "$count" -ge 10 ]; then
                echo "Process did not stop after 10 seconds. Force killing..."
                kill -9 "$pid"
                break
            fi
        done
        rm "$PID_FILE"
        echo "$APP_NAME stopped"
    else
        echo "$APP_NAME is not running (Process not found)"
        rm "$PID_FILE"
    fi
}
64
+
65
restart() {
    # Stop, give the OS a moment to release the port/PID, then start again.
    stop
    sleep 2
    start
}
70
+
71
status() {
    # Report running/stopped based on the PID file and process liveness.
    if [ -f "$PID_FILE" ]; then
        pid=$(cat "$PID_FILE")
        if ps -p "$pid" > /dev/null; then
            echo "$APP_NAME is running (PID: $pid)"
        else
            echo "$APP_NAME is stopped (PID file exists but process is gone)"
        fi
    else
        echo "$APP_NAME is stopped"
    fi
}
83
+
84
+ case "$1" in
85
+ start)
86
+ start
87
+ ;;
88
+ stop)
89
+ stop
90
+ ;;
91
+ restart)
92
+ restart
93
+ ;;
94
+ status)
95
+ status
96
+ ;;
97
+ *)
98
+ echo "Usage: $0 {start|stop|restart|status}"
99
+ exit 1
100
+ ;;
101
+ esac
102
+
103
+ exit 0
app.pid ADDED
@@ -0,0 +1 @@
 
 
1
+ 171729
app.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import hmac
import logging
import os
import shutil
import threading
import time
from functools import wraps

import requests
import yaml
from flask import Flask, request, jsonify, g
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from services.logger import get_app_logger, RequestLogger, task_id_var, get_process_worker_logger, \
    get_upload_worker_logger
from services.tts_service import TTSService
from services.queue_manager import QueueManager
from services.r2_uploader import R2Uploader
from services.uvr5_service import UVR5Service
from services.merger_service import MergerService
21
+
22
+ logger = get_app_logger()
23
+
24
# Configuration loading
def load_config(config_path='config.yaml'):
    """Parse the YAML configuration file at *config_path* and return it as a dict."""
    with open(config_path, 'r', encoding='utf-8') as cfg_file:
        return yaml.safe_load(cfg_file)

# Module-level config, loaded once at import time.
config = load_config()
30
+
31
# Auth Helpers
def check_auth(username, password):
    """Validate HTTP Basic credentials against the configured API user.

    Uses hmac.compare_digest (constant-time comparison) so the check does
    not leak credential length/prefix information via timing. Missing
    config keys now yield False (401) instead of raising KeyError (500).
    """
    app_config = config.get('app', {})
    user_ok = hmac.compare_digest(str(username), str(app_config.get('api_username', '')))
    pass_ok = hmac.compare_digest(str(password), str(app_config.get('api_password', '')))
    # Evaluate both comparisons so username mismatch is not observable
    # through an early-exit timing difference.
    return user_ok and pass_ok
35
+
36
def authenticate():
    """Return a 401 response challenging the client for HTTP Basic credentials."""
    challenge = {'WWW-Authenticate': 'Basic realm="Login Required"'}
    return jsonify({'error': 'Authentication required'}), 401, challenge
38
+
39
def requires_auth(f):
    """Decorator enforcing HTTP Basic auth on a Flask view function."""
    @wraps(f)
    def decorated(*args, **kwargs):
        credentials = request.authorization
        if credentials and check_auth(credentials.username, credentials.password):
            return f(*args, **kwargs)
        # No credentials supplied, or they did not match: challenge the client.
        return authenticate()
    return decorated
47
+
48
def send_hook_with_retry(url: str, data: dict, max_retries: int = 3):
    """POST *data* as JSON to the callback *url*, retrying transient errors.

    Retries up to *max_retries* times with exponential backoff on HTTP
    500/502/503/504. Returns the Response on success, or None on failure —
    callbacks are best-effort and must never crash a worker thread.

    Fixes: the Session is now closed deterministically (it previously
    leaked pooled connections), and the dead `pass` after logging is gone.
    """
    retries = Retry(total=max_retries, backoff_factor=1,
                    status_forcelist=[500, 502, 503, 504])
    adapter = HTTPAdapter(max_retries=retries)
    with requests.Session() as session:
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        try:
            response = session.post(url, json=data, timeout=10)
            response.raise_for_status()
            return response
        except Exception as e:
            # Swallow deliberately: hook delivery failure must not abort the task.
            logger.error(f"Failed to send hook to {url}: {e}")
            return None
60
+
61
def download_file(url: str, path: str):
    """Stream the resource at *url* to *path* on disk.

    Raises requests.HTTPError for non-2xx responses. The response is used
    as a context manager so the underlying HTTP connection is released
    even if the disk write fails (the original leaked it on error).
    """
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # filter out keep-alive chunks
                    f.write(chunk)
67
+
68
# Initialize Flask
app = Flask(__name__)
# Cap request bodies at 100 MB.
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024

# Initialize Services
queue_manager = QueueManager(config['redis'])  # Redis-backed task/upload queues
r2_uploader = R2Uploader(config['r2'])         # Cloudflare R2 object storage client
tts_service = TTSService(config)
uvr5_service = UVR5Service(config)
merger_service = MergerService(config)

# Temp Dir for Videos
# Scratch space for downloaded inputs and merged outputs; workers clean it up.
VIDEO_TEMP_DIR = 'data/temp_videos'
os.makedirs(VIDEO_TEMP_DIR, exist_ok=True)
82
+
83
+ # -------------------------------------------------------------------------
84
+ # Workers
85
+ # -------------------------------------------------------------------------
86
+
87
def process_worker():
    """
    Main Pipeline Worker:
    1. Fetch Task
    2. Download Video
    3. Run TTS Generation
    4. Run UVR5 Separation (get BGM)
    5. Merge (Video + TTS + BGM)
    6. Push to Upload Queue

    Runs forever in a daemon thread. On any per-task failure it sends a
    "failed" callback (if a hook_url was given) and always cleans up the
    task's temporary files in the finally block.
    """
    worker_logger = get_process_worker_logger()
    worker_logger.info("Main Process Worker started")

    while True:
        try:
            task = queue_manager.get_process_task()
            if not task:
                # Queue empty: avoid a busy-wait spin.
                time.sleep(1)
                continue

            task_id = task.get('task_id')
            # Bind the task id into the logging context for this iteration.
            token = task_id_var.set(task_id)

            # Context variables for cleanup
            local_video_path = None
            bgm_path = None
            vocals_path = None
            task_tts_dir = None
            final_output_path = None
            success = False

            try:
                worker_logger.info("Processing started.")

                # 1. Download Video
                video_url = task['data'].get('video_url')
                if not video_url:
                    raise ValueError("Missing video_url")

                local_video_path = os.path.join(VIDEO_TEMP_DIR, f"{task_id}_input.mp4")
                worker_logger.info(f"Downloading video from {video_url}")
                download_file(video_url, local_video_path)

                # 2. Run TTS
                worker_logger.info("Running TTS...")
                tts_result = tts_service.process_task(task)
                segments = tts_result['segments']
                task_tts_dir = tts_result['task_dir']

                # 3. Run UVR5 (vocal/background separation on the video's audio)
                worker_logger.info("Running UVR5 Separation...")
                vocals_path, bgm_path = uvr5_service.process_audio(local_video_path, task_id)

                if not bgm_path or not os.path.exists(bgm_path):
                    raise Exception("UVR5 failed to produce background music.")

                # 4. Merge generated speech + extracted BGM back into the video
                worker_logger.info("Merging Audio and Video...")
                final_output_path = os.path.join(VIDEO_TEMP_DIR, f"{task_id}_final.mp4")

                merger_service.merge_video(
                    video_path=local_video_path,
                    bgm_path=bgm_path,
                    segments=segments,
                    output_path=final_output_path
                )

                # 5. Push to Upload
                upload_task = {
                    'task_id': task_id,
                    'file_path': final_output_path,
                    'hook_url': task['data'].get('hook_url'),
                }
                queue_manager.push_upload_task(upload_task)
                success = True

            except Exception as e:
                worker_logger.error(f"Task processing failed: {e}", exc_info=True)

                # Best-effort failure callback to the client, if one was requested.
                if 'hook_url' in task.get('data', {}):
                    hook_url = task['data']['hook_url']
                    failure_payload = {
                        "task_uuid": task_id,
                        "status": "failed",
                        "timestamp": int(time.time()),
                        "error_message": str(e)
                    }
                    send_hook_with_retry(hook_url, failure_payload)

            finally:
                # Cleanup Logic
                try:
                    if local_video_path and os.path.exists(local_video_path):
                        os.remove(local_video_path)

                    if bgm_path and os.path.exists(bgm_path):
                        os.remove(bgm_path)

                    if vocals_path and os.path.exists(vocals_path):
                        os.remove(vocals_path)

                    if task_tts_dir and os.path.exists(task_tts_dir):
                        shutil.rmtree(task_tts_dir)

                    # Only delete final output if we FAILED.
                    # If success, upload worker handles it.
                    if not success and final_output_path and os.path.exists(final_output_path):
                        os.remove(final_output_path)

                except Exception as cleanup_err:
                    # Cleanup failures are logged but never mask the task result.
                    worker_logger.warning(f"Cleanup error: {cleanup_err}")

                task_id_var.reset(token)

        except Exception as e:
            # Loop-level guard: keep the worker alive on unexpected errors.
            worker_logger.error(f"Worker Loop Error: {e}")
            time.sleep(5)
204
+
205
def upload_worker():
    """
    Upload Worker:
    1. Upload Final Video
    2. Send Success Callback
    3. Cleanup Final Video

    Runs forever in a daemon thread. On failure a "failed" callback is
    sent (best-effort); the merged video file is always removed.

    Fix: the loop-level exception handler now logs through worker_logger
    (it previously used the app logger, inconsistent with process_worker).
    """
    worker_logger = get_upload_worker_logger()
    worker_logger.info("Upload Worker started")

    while True:
        try:
            result = queue_manager.get_upload_task(timeout=5)
            if not result:
                continue

            task_id = result.get('task_id')
            # Bind the task id into the logging context for this iteration.
            token = task_id_var.set(task_id)

            file_path = result.get('file_path')
            hook_url = result.get('hook_url')

            try:
                worker_logger.info(f"Uploading result: {file_path}")

                file_url = None
                if file_path and os.path.exists(file_path):
                    object_key = f"{task_id}.mp4"
                    file_url = r2_uploader.upload_file(file_path, object_key=object_key)
                else:
                    raise FileNotFoundError(f"File to upload not found: {file_path}")

                if hook_url:
                    success_payload = {
                        "task_uuid": task_id,
                        "status": "success",
                        "timestamp": int(time.time()),
                        "result_url": file_url
                    }
                    worker_logger.info(f"Sending success callback to {hook_url}")
                    send_hook_with_retry(hook_url, success_payload)

            except Exception as e:
                worker_logger.error(f"Upload failed: {e}", exc_info=True)
                if hook_url:
                    failure_payload = {
                        "task_uuid": task_id,
                        "status": "failed",
                        "timestamp": int(time.time()),
                        "error_message": str(e)
                    }
                    send_hook_with_retry(hook_url, failure_payload)
            finally:
                # Cleanup the final video file
                if file_path and os.path.exists(file_path):
                    try:
                        os.remove(file_path)
                        worker_logger.info(f"Removed final video: {file_path}")
                    except Exception as e:
                        worker_logger.warning(f"Failed to remove file: {e}")

                task_id_var.reset(token)

        except Exception as e:
            # Loop-level guard: keep the worker alive on unexpected errors.
            worker_logger.error(f"Upload Loop Error: {e}")
            time.sleep(5)
271
+
272
+ # -------------------------------------------------------------------------
273
+ # Flask Routes
274
+ # -------------------------------------------------------------------------
275
+
276
@app.before_request
def before_request():
    # Record the request start time so after_request can log the duration.
    g.start_time = time.time()
279
+
280
@app.after_request
def after_request(response):
    # Log every request (with its elapsed time) through RequestLogger.
    # hasattr guard: before_request may not have run for some error paths.
    if hasattr(g, 'start_time'):
        duration = time.time() - g.start_time
        RequestLogger.log_request(request, response, duration)
    return response
286
+
287
@app.route('/dubbing/character', methods=['POST'])
@requires_auth
def generate():
    """Queue a new dubbing task.

    Expects a JSON body with character_voice, content, hook_url and
    video_url (see api.md). Returns 201 with the queued task id.

    Fix: request.json raises for a missing/invalid JSON body, which the
    broad except turned into a 500; get_json(silent=True) lets us return
    a proper 400 instead.
    """
    try:
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'Request body must be JSON'}), 400

        # Basic Validation
        required = ['character_voice', 'content', 'hook_url', 'video_url']
        for field in required:
            if not data.get(field):
                return jsonify({'error': f'Missing field: {field}'}), 400

        priority = data.get('priority', 3)
        if priority not in range(1, 6):
            return jsonify({'error': 'Priority must be 1-5'}), 400

        task_id = queue_manager.add_task(data, priority)
        logger.info(f"Created Task: {task_id}")

        return jsonify({
            'task_uuid': task_id,
            'status': 'queued',
            'message': 'Task queued successfully'
        }), 201

    except Exception as e:
        logger.error(f"API Error: {e}")
        return jsonify({'error': str(e)}), 500
314
+
315
@app.route('/dubbing/character/tasks/<task_id>/cancel', methods=['DELETE'])
@requires_auth
def cancel_task(task_id: str):
    """Remove a queued task before the worker picks it up."""
    try:
        removed = queue_manager.delete_process_task(task_id)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    if removed:
        return jsonify({'message': 'Task canceled'}), 200
    # Already running, finished, or never existed.
    return jsonify({'message': 'Task not found or already processed'}), 404
324
+
325
@app.errorhandler(500)
def internal_error(error):
    # Catch-all for unhandled exceptions: log the detail server-side,
    # return only a generic message to the client.
    logger.error(f"500 Error: {error}")
    return jsonify({'error': 'Internal server error'}), 500
329
+
330
def main():
    """Start the service: prepare directories, launch worker threads, run Flask."""
    logger.info("Starting Service...")

    # Directories
    os.makedirs(config['tts']['output_dir'], exist_ok=True)
    os.makedirs(config['tts']['voices_dir'], exist_ok=True)
    os.makedirs(VIDEO_TEMP_DIR, exist_ok=True)

    # Threads
    # NOTE(review): daemon threads die abruptly with the main process;
    # presumably in-flight tasks are re-fetched from the Redis queue on
    # restart — confirm QueueManager's redelivery semantics.
    threading.Thread(target=process_worker, daemon=True).start()
    threading.Thread(target=upload_worker, daemon=True).start()

    app.run(
        host=config['app']['host'],
        port=config['app']['port'],
        debug=config['app']['debug']
    )

if __name__ == '__main__':
    main()
config.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ app:
2
+ host: '0.0.0.0'
3
+ port: 8000
4
+ debug: false
5
+ api_username: admin
6
+ api_password: admin  # NOTE (review): change these default credentials before deploying
7
+
8
+ # Redis 配置
9
+ redis:
10
+ host: 'localhost'
11
+ port: 6379
12
+ db: 0
13
+ password: null
14
+ queue_key: 'tts:generate'
15
+ queue_key_hash: 'tts:generate_hash'
16
+ upload_queue_key: 'tts:upload'
17
+ max_connections: 5
18
+
19
+ # Cloudflare R2 configuration
+ # SECURITY (review): live access keys are committed below. Rotate these
+ # credentials immediately and load them from environment variables or a
+ # secrets manager instead of checking them into version control.
20
+ r2:
21
+ access_key_id: '2c4cef629ca75ffe03376206c0a3e365'
22
+ secret_access_key: '42cb6c0dedd621bbe2a38eb52c5d4b4738d69038705020c8cd14018dcc30ee53'
23
+ bucket_name: 'ls-tts'
24
+ endpoint_url: 'https://3322fcf6693dc79f8e04aa2f4918bc44.r2.cloudflarestorage.com'
25
+ public_url: 'https://tts.luckyshort.net'
26
+
27
+ # TTS 服务配置
28
+ tts:
29
+ checkpoint_file: 'pt-br/model_last.safetensors' # 模型文件路径
30
+ vocab_file: 'vocab.txt' # 词表文件 (如果需要)
31
+ vocoder_name: 'vocos' # 默认 vocoder
32
+ remove_silence: true
33
+ speed: 1.0
34
+ device: 'cuda' # 'cuda' or 'cpu'
35
+ # 路径配置
36
+ voices_dir: 'data/voices' # 参考音频缓存目录
37
+ output_dir: 'data/outputs' # 生成结果临时目录
38
+
39
+
40
+ uvr5:
41
+ model_dir: './models/uvr5'
42
+ output_dir: './temp/uvr5'
43
+ uvr5_model: 'UVR-MDX-NET-Inst_HQ_4' # UVR5 模型名称
merger.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Automatic audio merging script - Tencent Cloud TTS cloned audio.

Mixes multiple cloned audio clips and BGM according to the audio
parameters, then muxes the result into the video.

Core features:
1. Smart audio processing strategies (fill / direct overlay / speed-up)
2. Anti-clipping safeguards (fade in/out, compression, limiting)
3. BGM background-music mixing
4. Chained atempo processing (beyond FFmpeg's 0.5-2.0 limit)
5. Muxing the mixed audio back into the video
"""

import logging
import math
import os
import subprocess
from dataclasses import dataclass
from typing import Dict, List, Optional

# Reuse the process_worker logger so merger output lands in the same
# log stream as the pipeline worker that drives this module.
logger = logging.getLogger('process_worker')

# ============================================================================
# Constants
# ============================================================================

SAFETY_MARGIN = 0.01  # safety gap between clips, in seconds
FADE_DURATION = 0.15  # fade-in/fade-out length, in seconds
VOLUME_LEVEL = 0.95  # pre-mix volume attenuation
COMPRESSOR_THRESHOLD = -12  # compressor threshold (dB)
COMPRESSOR_RATIO = 4  # compression ratio
LIMITER_LEVEL = 0.95  # limiter ceiling
MAX_SPEED_RATIO = 4.0  # max speed-up factor, prevents extreme acceleration
35
+
36
+
37
+ # ============================================================================
38
+ # 数据类定义
39
+ # ============================================================================
40
+
41
@dataclass
class AudioParam:
    """Parameters describing one cloned audio clip on the timeline."""
    start_secs: float  # timeline start in seconds (required)
    end_secs: float  # timeline end in seconds (required)
    clone_audio_path: str  # path to the cloned audio file (required)
    original_audio_length: float  # duration of the original audio (required)
    clone_audio_length: float  # duration of the cloned audio (required)
    audio_sort_num: int  # ordering index (required)

    def __post_init__(self):
        """Validate every field immediately after construction."""
        # Path checks raise their own exception types, so handle them first.
        if not self.clone_audio_path:
            raise ValueError("clone_audio_path 不能为空")
        if not os.path.exists(self.clone_audio_path):
            raise FileNotFoundError(f"音频文件不存在: {self.clone_audio_path}")
        # Remaining numeric checks all raise ValueError; evaluated in the
        # same order as the original so the first failure wins.
        checks = (
            (self.start_secs >= 0,
             f"start_secs 必须非负,实际值: {self.start_secs}"),
            (self.end_secs > self.start_secs,
             f"end_secs 必须大于 start_secs,start_secs: {self.start_secs}, end_secs: {self.end_secs}"),
            (self.original_audio_length > 0,
             f"original_audio_length 必须大于0,实际值: {self.original_audio_length}"),
            (self.clone_audio_length > 0,
             f"clone_audio_length 必须大于0,实际值: {self.clone_audio_length}"),
            (self.audio_sort_num >= 0,
             f"audio_sort_num 必须非负,实际值: {self.audio_sort_num}"),
        )
        for ok, message in checks:
            if not ok:
                raise ValueError(message)
67
+
68
+
69
@dataclass
class AudioMerge:
    """Top-level merge job parameters."""
    output_path: str  # output path (required)
    bgm_path: str  # BGM audio path (required)
    input_path: str  # input path (required)
    input_type: str = "video"  # audio, video
    speed_strategy: str = "max"  # audio strategy: max (default), mix, normal
    # NOTE(review): default of None for a "required" list is fragile —
    # __post_init__ rejects it, but a field(default_factory=...) or no
    # default would express the contract better.
    audio_params: List[AudioParam] = None  # list of AudioParam (required)

    def __post_init__(self):
        """Validate fields and normalize the clip ordering."""
        if not self.output_path:
            raise ValueError("output_path 不能为空")
        if not self.bgm_path:
            raise ValueError("bgm_path 不能为空")
        if not os.path.exists(self.bgm_path):
            raise FileNotFoundError(f"BGM文件不存在: {self.bgm_path}")
        if not self.input_path:
            raise ValueError("input_path 不能为空")
        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")
        # Output and input paths must differ (ffmpeg cannot overwrite in place).
        output_abs = os.path.abspath(self.output_path)
        input_abs = os.path.abspath(self.input_path)
        if output_abs == input_abs:
            raise ValueError(f"output_path 和 input_path 不能相同: {output_abs}")
        if not self.audio_params or len(self.audio_params) == 0:
            raise ValueError("audio_params 不能为空")
        if self.speed_strategy not in ["mix", "normal", "max"]:
            raise ValueError(f"speed_strategy 必须是 mix/normal/max 之一,实际值: {self.speed_strategy}")
        # Sort clips by their ordering index.
        self.audio_params = sorted(self.audio_params, key=lambda x: x.audio_sort_num)
102
+
103
+
104
+ # ============================================================================
105
+ # 工具函数
106
+ # ============================================================================
107
+
108
def get_audio_duration(audio_path: str) -> float:
    """Return the duration (in seconds) of *audio_path* using ffprobe.

    Raises a generic Exception (with a Chinese message, matching the rest
    of this module) on ffprobe timeout or non-zero exit.
    """
    cmd = [
        'ffprobe', '-v', 'error',
        '-show_entries', 'format=duration',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        audio_path
    ]
    try:
        result = subprocess.check_output(
            cmd,
            stderr=subprocess.STDOUT,
            timeout=30  # 30-second timeout
        )
        return float(result.decode().strip())
    except subprocess.TimeoutExpired:
        raise Exception(f"获取音频时长超时: {audio_path}")
    except subprocess.CalledProcessError as e:
        error_output = e.output.decode() if e.output else "未知错误"
        raise Exception(f"获取音频时长失败: {audio_path}\n{error_output}")
128
+
129
+
130
def build_atempo_chain(speed_ratio: float) -> str:
    """Build an FFmpeg ``atempo`` filter chain for *speed_ratio*.

    FFmpeg's atempo filter only accepts factors in [0.5, 2.0], so ratios
    outside that range are decomposed into a chain of stages, e.g.
    4.0 -> "atempo=2.0,atempo=2.000000,".

    Returns an empty string for ratio 1.0 (no tempo change), otherwise a
    comma-terminated chain ready to prepend to further audio filters.

    Raises:
        ValueError: if *speed_ratio* is not strictly positive (previously
        this fell through to math.log and failed with a cryptic
        "math domain error").
    """
    if speed_ratio <= 0:
        raise ValueError(f"speed_ratio must be positive, got {speed_ratio}")
    if speed_ratio == 1.0:
        return ""
    if 0.5 <= speed_ratio <= 2.0:
        return f"atempo={speed_ratio:.6f},"
    if speed_ratio < 0.5:
        # Stack 0.5x stages until the remaining factor fits in range.
        stages = int(math.ceil(math.log(speed_ratio) / math.log(0.5)))
        final_ratio = speed_ratio / (0.5 ** (stages - 1))
        return "atempo=0.5," * (stages - 1) + f"atempo={final_ratio:.6f},"
    # speed_ratio > 2.0: stack 2.0x stages the same way.
    stages = int(math.ceil(math.log(speed_ratio) / math.log(2.0)))
    final_ratio = speed_ratio / (2.0 ** (stages - 1))
    return "atempo=2.0," * (stages - 1) + f"atempo={final_ratio:.6f},"
143
+
144
+
145
+ # ============================================================================
146
+ # 音频策略计算
147
+ # ============================================================================
148
+
149
def calculate_audio_strategy(
    audio_duration: float,
    srt_duration: float,
    next_gap: Optional[float],
    speed_strategy: str = 'max',
    start_time: float = 0.0,
    end_time: float = 0.0
) -> Dict:
    """Decide how one cloned clip should be fitted onto the timeline.

    Strategies:
      - 'mix':    keep the clone as-is; overflow is simply mixed over what follows.
      - 'normal': fit within the subtitle slot, speeding up if needed.
      - 'max':    may also spill into the gap before the next clip; speeds up
                  only when even slot+gap is too short.
    Any other value falls back to 'normal'.

    Returns a dict with keys: strategy ('direct'/'speedup'), speed_ratio,
    target_duration, actual_duration, description (human-readable log line).

    Fix: the '[max] direct' description contained a mojibake run
    ("克隆/��理后") — restored to "克隆/处理后" to match the other branches.
    """
    # None means "no following clip", i.e. unlimited room to spill into.
    next_gap_val = next_gap if next_gap is not None else float('inf')

    if speed_strategy == 'mix':
        clone_ratio = audio_duration / srt_duration if srt_duration > 0 else 0
        description = (
            f'[mix] 保持原音 | 原始: {srt_duration:.3f}s | 克隆: {audio_duration:.3f}s ({clone_ratio:.3f}x) | 处理后: {audio_duration:.3f}s | '
            f'速度: {1.0:.3f}x (克隆/处理后 = {audio_duration:.3f}/{audio_duration:.3f}) | '
            f'时间轴: {start_time:.3f}s -> {end_time:.3f}s | 超出部分会混音'
        )
        return {
            'strategy': 'direct',
            'speed_ratio': 1.0,
            'target_duration': audio_duration,
            'actual_duration': audio_duration,
            'description': description
        }

    if speed_strategy == 'normal':
        target_dur = srt_duration + SAFETY_MARGIN
        if audio_duration <= target_dur:
            clone_ratio = audio_duration / srt_duration if srt_duration > 0 else 0
            description = (
                f'[normal] 直接使用 | 原始: {srt_duration:.3f}s | 克隆: {audio_duration:.3f}s ({clone_ratio:.3f}x) | 处理后: {audio_duration:.3f}s | '
                f'速度: {1.0:.3f}x (克隆/处理后 = {audio_duration:.3f}/{audio_duration:.3f}) | '
                f'时间轴: {start_time:.3f}s -> {end_time:.3f}s | 未超出字幕时长'
            )
            return {
                'strategy': 'direct',
                'speed_ratio': 1.0,
                'target_duration': audio_duration,
                'actual_duration': audio_duration,
                'description': description
            }
        speed_ratio = audio_duration / target_dur
        # Clamp acceleration to the configured maximum (default 4x).
        if speed_ratio > MAX_SPEED_RATIO:
            original_target_dur = target_dur
            original_speed_ratio = speed_ratio
            logger.warning(
                f'⚠️ 加速倍数超过限制 | 原始加速: {original_speed_ratio:.3f}x | '
                f'已限制为: {MAX_SPEED_RATIO}x | 音频时长: {audio_duration:.3f}s | '
                f'目标时长: {original_target_dur:.3f}s -> {audio_duration / MAX_SPEED_RATIO:.3f}s | '
                f'时间轴: {start_time:.3f}s -> {end_time:.3f}s'
            )
            speed_ratio = MAX_SPEED_RATIO
            target_dur = audio_duration / MAX_SPEED_RATIO
        clone_ratio = audio_duration / srt_duration if srt_duration > 0 else 0
        description = (
            f'[normal] 提速到结束 | 原始: {srt_duration:.3f}s | 克隆: {audio_duration:.3f}s ({clone_ratio:.3f}x) | 处理后: {target_dur:.3f}s | '
            f'速度: {speed_ratio:.3f}x (克隆/处理后 = {audio_duration:.3f}/{target_dur:.3f}) | '
            f'时间轴: {start_time:.3f}s -> {end_time:.3f}s'
        )
        return {
            'strategy': 'speedup',
            'speed_ratio': speed_ratio,
            'target_duration': target_dur,
            'actual_duration': audio_duration,
            'description': description
        }

    if speed_strategy == 'max':
        max_available_dur = srt_duration + next_gap_val
        if audio_duration <= max_available_dur:
            clone_ratio = audio_duration / srt_duration if srt_duration > 0 else 0
            description = (
                f'[max] 直接使用 | 原始: {srt_duration:.3f}s | 克隆: {audio_duration:.3f}s ({clone_ratio:.3f}x) | 处理后: {audio_duration:.3f}s | '
                f'速度: {1.0:.3f}x (克隆/处理后 = {audio_duration:.3f}/{audio_duration:.3f}) | '
                f'时间轴: {start_time:.3f}s -> {end_time:.3f}s | 间隙: {next_gap_val:.3f}s'
            )
            return {
                'strategy': 'direct',
                'speed_ratio': 1.0,
                'target_duration': audio_duration,
                'actual_duration': audio_duration,
                'description': description
            }
        target_dur = max_available_dur - SAFETY_MARGIN
        speed_ratio = audio_duration / target_dur
        # Clamp acceleration to the configured maximum (default 4x).
        if speed_ratio > MAX_SPEED_RATIO:
            original_target_dur = target_dur
            original_speed_ratio = speed_ratio
            logger.warning(
                f'⚠️ 加速倍数超过限制 | 原始加速: {original_speed_ratio:.3f}x | '
                f'已限制为: {MAX_SPEED_RATIO}x | 音频时长: {audio_duration:.3f}s | '
                f'目标时长: {original_target_dur:.3f}s -> {audio_duration / MAX_SPEED_RATIO:.3f}s | '
                f'时间轴: {start_time:.3f}s -> {end_time:.3f}s'
            )
            speed_ratio = MAX_SPEED_RATIO
            target_dur = audio_duration / MAX_SPEED_RATIO
        clone_ratio = audio_duration / srt_duration if srt_duration > 0 else 0
        description = (
            f'[max] 提速到下个 | 原始: {srt_duration:.3f}s | 克隆: {audio_duration:.3f}s ({clone_ratio:.3f}x) | 处理后: {target_dur:.3f}s | '
            f'速度: {speed_ratio:.3f}x (克隆/处理后 = {audio_duration:.3f}/{target_dur:.3f}) | '
            f'时间轴: {start_time:.3f}s -> {end_time:.3f}s | 间隙: {next_gap_val:.3f}s'
        )
        return {
            'strategy': 'speedup',
            'speed_ratio': speed_ratio,
            'target_duration': target_dur,
            'actual_duration': audio_duration,
            'description': description
        }

    # Unknown strategy name: fall back to the conservative 'normal' behaviour.
    return calculate_audio_strategy(audio_duration, srt_duration, next_gap, 'normal', start_time, end_time)
263
+
264
+
265
def analyze_audio_tracks(
    audio_params: List[AudioParam],
    speed_strategy: str = 'max',
    task_logger=None
) -> List[Dict]:
    """Analyze the audio tracks and compute a processing strategy for each.

    Uses each clip's start_secs/end_secs to derive its timeline position
    and the gap before the next clip, then delegates the fitting decision
    to calculate_audio_strategy.
    """
    # Use the provided logger, or fall back to the module logger.
    log = task_logger or logger

    tracks = []

    for idx, param in enumerate(audio_params):
        # Use the provided clone_audio_length (already validated in __post_init__).
        audio_duration = param.clone_audio_length

        # Use original_audio_length as the subtitle (SRT) slot duration.
        srt_duration = param.original_audio_length

        # Use the provided start_secs and end_secs.
        start_time = param.start_secs
        end_time = param.end_secs

        # Compute the gap before the next clip.
        next_gap = None
        if idx < len(audio_params) - 1:
            # End time of the current clip.
            current_end_time = end_time
            # Start time of the next clip.
            next_param = audio_params[idx + 1]
            next_start_time = next_param.start_secs
            # Real gap: next clip start minus current clip end.
            # Contiguous clips give gap = 0; a gap gives > 0; overlap gives < 0.
            next_gap = next_start_time - current_end_time

        # Compute the processing strategy.
        # For the last clip, a 'max' strategy falls back to 'normal'
        # (avoids the infinite gap producing speed_ratio = 0).
        effective_strategy = speed_strategy
        is_last_track = (idx == len(audio_params) - 1)
        if is_last_track and speed_strategy == 'max':
            effective_strategy = 'normal'

        strategy = calculate_audio_strategy(
            audio_duration,
            srt_duration,
            next_gap,
            effective_strategy,
            start_time,
            end_time
        )

        tracks.append({
            'id': param.audio_sort_num,
            'audio_file': param.clone_audio_path,
            'start_time': start_time,
            'end_time': end_time,
            'srt_duration': srt_duration,
            'audio_duration': audio_duration,
            'next_gap': next_gap,
            'strategy': strategy,
            'param': param
        })

        log.info(f" → 音频 [{param.audio_sort_num:03d}]: {strategy['description']}")

    return tracks
333
+
334
+
335
+ # ============================================================================
336
+ # FFmpeg Filter Complex 构建
337
+ # ============================================================================
338
+
339
def build_filter_complex_for_video(
    audio_tracks: List[Dict],
    has_bgm: bool
) -> str:
    """Build the FFmpeg filter_complex string for the video-muxing case.

    Input layout is [0:video] [1..N:clone audio] [N+1:BGM (if any)].
    Each clip is tempo-adjusted, trimmed, attenuated, faded and delayed
    into place; all clips (plus BGM) are amix-ed, then compressed and
    limited to prevent clipping. The video stream itself is not filtered
    (the caller maps 0:v directly).
    """
    filters = []

    # 1. Process each cloned audio clip.
    for idx, track in enumerate(audio_tracks):
        input_idx = idx + 1  # input index: [0:video] [1:audio1] [2:audio2] ...
        audio_label = f"a{idx}"
        strategy = track['strategy']

        speed_ratio = strategy['speed_ratio']
        target_duration = strategy['target_duration']
        start_time = track['start_time']

        # Build the atempo chain.
        atempo_chain = build_atempo_chain(speed_ratio)

        # Keep the fade shorter than half the clip so fades never overlap.
        safe_fade_dur = min(FADE_DURATION, target_duration / 2.0)

        # Filter pipeline: tempo -> trim -> reset PTS -> attenuate -> fades -> delay.
        audio_filter = (
            f"[{input_idx}:a]"
            f"{atempo_chain}"  # tempo change (if needed)
            f"atrim=start=0:end={target_duration:.3f},"  # trim to target duration
            f"asetpts=PTS-STARTPTS,"  # reset timestamps
            f"volume={VOLUME_LEVEL},"  # pre-mix attenuation
            f"afade=t=in:st=0:d={safe_fade_dur:.3f}:curve=esin,"  # fade in
            f"afade=t=out:st={max(0.0, target_duration - safe_fade_dur):.3f}:d={safe_fade_dur:.3f}:curve=esin,"  # fade out
            f"adelay={int(start_time * 1000)}|{int(start_time * 1000)}"  # delay to timeline position (last filter, no trailing comma)
            f"[{audio_label}]"
        )
        filters.append(audio_filter)

    # 2. Process the BGM track.
    if has_bgm:
        bgm_input_idx = len(audio_tracks) + 1  # BGM is the last input
        bgm_filter = f"[{bgm_input_idx}:a]volume=1.0[bgm]"
        filters.append(bgm_filter)

    # 3. Mix all prepared streams together.
    audio_labels = "".join([f"[a{i}]" for i in range(len(audio_tracks))])
    if has_bgm:
        audio_labels += "[bgm]"
        mix_input_count = len(audio_tracks) + 1
    else:
        mix_input_count = len(audio_tracks)

    mix_filter = (
        f"{audio_labels}"
        f"amix=inputs={mix_input_count}:duration=longest:normalize=0[mixed]"
    )
    filters.append(mix_filter)

    # 4. Dynamics: compressor + limiter to tame the summed signal.
    dynamics_filter = (
        f"[mixed]"
        f"acompressor=threshold={COMPRESSOR_THRESHOLD}dB:ratio={COMPRESSOR_RATIO}:attack=5:release=50,"
        f"alimiter=limit={LIMITER_LEVEL}"
        f"[mixout]"
    )
    filters.append(dynamics_filter)

    # 5. Video stream (mapped directly, no subtitle processing).
    # Note: the video stream bypasses the filter graph; the command line
    # uses -map 0:v rather than -map [vout].

    # Drop empty strings to avoid emitting empty filter segments.
    filters = [f for f in filters if f and f.strip()]
    return ";".join(filters)
412
+
413
def build_filter_complex_for_audio(
    audio_tracks: List[Dict],
    has_bgm: bool
) -> str:
    """Build the FFmpeg filter_complex string for audio-only output.

    Processing pipeline:
    1. Each clone track: speed change (if needed) -> trim -> PTS reset ->
       volume reduction -> fade in/out -> delay alignment.
    2. BGM: volume stage.
    3. Mixing via amix.
    4. Dynamics: compressor + limiter, ending at the [out] label.

    Args:
        audio_tracks: Prepared track dicts (strategy, start_time, ...).
        has_bgm: Whether a BGM input follows the clone-audio inputs.

    Returns:
        The filter_complex string.
    """
    chains = []

    # 1. One processing chain per clone audio input.
    for idx, track in enumerate(audio_tracks):
        # No video input in audio mode, so input indices start at 0.
        in_idx = idx
        label = f"a{idx}"
        strat = track['strategy']

        speed = strat['speed_ratio']
        tgt = strat['target_duration']
        start = track['start_time']

        tempo = build_atempo_chain(speed)
        # Fade duration capped at half the clip length.
        fade = min(FADE_DURATION, tgt / 2.0)
        delay_ms = int(start * 1000)

        chains.append(
            f"[{in_idx}:a]"
            f"{tempo}"                                          # speed change (if any)
            f"atrim=start=0:end={tgt:.3f},"                     # trim to target duration
            f"asetpts=PTS-STARTPTS,"                            # reset timestamps
            f"volume={VOLUME_LEVEL},"                           # pre-lower volume
            f"afade=t=in:st=0:d={fade:.3f}:curve=esin,"         # fade in
            f"afade=t=out:st={max(0.0, tgt - fade):.3f}:d={fade:.3f}:curve=esin,"  # fade out
            f"adelay={delay_ms}|{delay_ms}"                     # align to timeline
            f"[{label}]"
        )

    # 2. BGM (if present) is the last input.
    if has_bgm:
        chains.append(f"[{len(audio_tracks)}:a]volume=1.0[bgm]")

    # 3. Mix every labelled stream together.
    labels = "".join(f"[a{i}]" for i in range(len(audio_tracks)))
    n_inputs = len(audio_tracks)
    if has_bgm:
        labels += "[bgm]"
        n_inputs += 1

    chains.append(
        f"{labels}"
        f"amix=inputs={n_inputs}:duration=longest:normalize=0[mixed]"
    )

    # 4. Dynamics: compressor then limiter.
    chains.append(
        f"[mixed]"
        f"acompressor=threshold={COMPRESSOR_THRESHOLD}dB:ratio={COMPRESSOR_RATIO}:attack=5:release=50,"
        f"alimiter=limit={LIMITER_LEVEL}"
        f"[out]"
    )

    # Drop empty entries so we never emit an empty filter stage.
    return ";".join(c for c in chains if c and c.strip())
497
+
498
+ # ============================================================================
499
+ # 主函数
500
+ # ============================================================================
501
+
502
def audio_auto_merge(audio_merge: AudioMerge, task_logger=None) -> Dict:
    """Mix clone audio tracks with BGM and mux/encode into the output file.

    Depending on ``audio_merge.input_type`` the result is either a WAV file
    (audio-only path) or an MP4 whose original video stream is stream-copied
    while the mixed audio is re-encoded to AAC.

    Args:
        audio_merge: Merge parameters (paths, input type, speed strategy,
            per-segment audio params).
        task_logger: Optional logger carrying a task_id; falls back to the
            module-level logger.

    Returns:
        Dict with 'output_file', 'file_size', 'track_count', 'has_bgm'.

    Raises:
        FileNotFoundError: If the input or BGM file is missing.
        Exception: On FFmpeg failure, timeout, or an invalid output file.
    """
    log = task_logger or logger

    log.info(f"开始音频合并 (策略: {audio_merge.speed_strategy})")

    # Validate inputs up front so we fail before spawning FFmpeg.
    if not os.path.exists(audio_merge.input_path):
        raise FileNotFoundError(f"输入文件不存在: {audio_merge.input_path}")
    if not os.path.exists(audio_merge.bgm_path):
        raise FileNotFoundError(f"BGM文件不存在: {audio_merge.bgm_path}")

    bgm_duration = get_audio_duration(audio_merge.bgm_path)
    log.debug(f"BGM 时长: {bgm_duration:.2f}s")

    # Decide per-track speed/trim strategy.
    log.info(f"分析 {len(audio_merge.audio_params)} 个音频轨道...")
    audio_tracks = analyze_audio_tracks(audio_merge.audio_params, audio_merge.speed_strategy, log)

    # Build the filter graph (a BGM input is always supplied).
    log.debug(f"构建 FFmpeg 滤镜...")
    if audio_merge.input_type == 'audio':
        filter_complex = build_filter_complex_for_audio(audio_tracks, True)
    else:
        filter_complex = build_filter_complex_for_video(audio_tracks, True)
    log.debug(f"滤镜长度: {len(filter_complex)} 字符")

    # Assemble the FFmpeg command line.
    ffmpeg_cmd = ['ffmpeg', '-nostdin']

    # Inputs: [video] + clone audios + BGM (the video input only in video mode).
    if audio_merge.input_type == "video":
        ffmpeg_cmd.extend(['-i', audio_merge.input_path])
    for track in audio_tracks:
        ffmpeg_cmd.extend(['-i', track['audio_file']])
    ffmpeg_cmd.extend(['-i', audio_merge.bgm_path])

    if audio_merge.input_type == "audio":
        ffmpeg_cmd.extend([
            '-filter_complex', filter_complex,
            '-map', '[out]',
            '-c:a', 'pcm_s16le',  # WAV output uses PCM encoding
            '-ar', '44100',       # 44.1 kHz sample rate
            '-ac', '2',           # stereo
            '-y',
            audio_merge.output_path
        ])
    else:
        ffmpeg_cmd.extend([
            '-filter_complex', filter_complex,
            '-map', '0:v',           # map the original video stream directly (unfiltered)
            '-map', '[mixout]',      # mixed audio from the filter graph
            '-c:v', 'copy',          # stream-copy video, no re-encode
            '-movflags', '+faststart',
            '-c:a', 'aac',           # encode mixed audio as AAC
            '-b:a', '128k',          # audio bitrate
            '-avoid_negative_ts', '1',
            '-f', 'mp4',
            '-y',
            audio_merge.output_path
        ])

    log.info(f"执行音频混合和视频合成...")
    log.debug(f"FFmpeg 命令: {' '.join(ffmpeg_cmd)}")

    process = None
    try:
        # Stream FFmpeg output live (FFmpeg logs to stderr; fold into stdout).
        process = subprocess.Popen(
            ffmpeg_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
            bufsize=1
        )

        # Echo output at DEBUG level.
        # NOTE(review): this read loop blocks until FFmpeg closes its output,
        # so the wait() timeout below only guards the post-EOF window.
        try:
            for line in process.stdout:
                log.debug(f"FFmpeg: {line.rstrip()}")
        finally:
            if process.stdout and not process.stdout.closed:
                process.stdout.close()

        # Wait for completion with a 30-minute ceiling.
        try:
            process.wait(timeout=1800)
        except subprocess.TimeoutExpired as e:
            log.error(f"FFmpeg 执行超时(30分钟),强制终止进程")
            process.kill()
            process.wait()
            # Chain the timeout so the root cause survives in the traceback.
            raise Exception("FFmpeg 执行超时(30分钟)") from e

        if process.returncode != 0:
            raise subprocess.CalledProcessError(process.returncode, ffmpeg_cmd)

        # Sanity-check the output file.
        if not os.path.exists(audio_merge.output_path):
            raise Exception("输出文件未生成")

        file_size = os.path.getsize(audio_merge.output_path)
        if file_size < 1024:
            raise Exception(f"输出文件异常(大小: {file_size} bytes)")

        log.info(
            f"✓ 音频合并完成: {os.path.basename(audio_merge.output_path)} ({file_size / 1024 / 1024:.2f} MB, {len(audio_tracks)} 轨道)")

        return {
            'output_file': audio_merge.output_path,
            'file_size': file_size,
            'track_count': len(audio_tracks),
            'has_bgm': True
        }

    except subprocess.CalledProcessError as e:
        error_msg = f"FFmpeg 执行失败,返回码: {e.returncode}"
        log.error(f"❌ {error_msg}")
        # Chain the CalledProcessError instead of discarding it.
        raise Exception(error_msg) from e
    except Exception as e:
        log.error(f"❌ 音频合并失败: {e}")
        raise
    finally:
        # Always reap a leftover FFmpeg process.
        if process is not None:
            try:
                if process.poll() is None:
                    log.warning(f"清理残留 FFmpeg 进程...")
                    try:
                        process.kill()
                        process.wait(timeout=5)
                    except subprocess.TimeoutExpired:
                        log.error(f"FFmpeg 进程无法终止,可能需要手动清理")
            except Exception as cleanup_error:
                log.error(f" ⚠️ 清理进程时出错: {cleanup_error}")
            finally:
                if process.stdout and not process.stdout.closed:
                    try:
                        process.stdout.close()
                    except Exception:
                        # Best-effort close of an already-broken pipe.
                        pass
pyproject.toml CHANGED
@@ -1,11 +1,19 @@
1
  [project]
2
  name = "f5-tts-pt-br"
3
  version = "0.1.0"
4
- description = "Add your description here"
5
  readme = "README.md"
6
- requires-python = ">=3.11"
7
  dependencies = [
8
  "f5-tts>=1.1.10",
9
- "torch>=2.9.1",
10
  "tqdm>=4.67.1",
 
 
 
 
 
 
 
 
11
  ]
 
1
  [project]
2
  name = "f5-tts-pt-br"
3
  version = "0.1.0"
4
+ description = "F5-TTS Voice Cloning API with UVR5 and Audio Merging"
5
  readme = "README.md"
6
+ requires-python = ">=3.10"
7
  dependencies = [
8
  "f5-tts>=1.1.10",
9
+ "torch>=2.1.0",
10
  "tqdm>=4.67.1",
11
+ "flask>=3.0.0",
12
+ "redis>=5.0.0",
13
+ "requests>=2.31.0",
14
+ "pyyaml>=6.0.0",
15
+ "boto3>=1.34.0",
16
+ "audio-separator[gpu]>=0.17.0",
17
+ "onnxruntime-gpu>=1.17.0",
18
+ "ffmpeg-python>=0.2.0"
19
  ]
services/logger.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import contextvars
4
+ from logging.handlers import RotatingFileHandler
5
+
6
+ # 日志配置常量
7
+ BASE_LOG_DIR = 'logs'
8
+ MAX_BYTES = 10 * 1024 * 1024 # 10MB
9
+ BACKUP_COUNT = 5
10
+
11
+ # ContextVar for task_id
12
+ task_id_var = contextvars.ContextVar("task_id", default=None)
13
+
14
+
15
class TaskIdFormatter(logging.Formatter):
    """Formatter that guarantees a ``task_id`` field on every record.

    The id is resolved from the ``task_id_var`` ContextVar with 'N/A' as the
    fallback.  Doing this in the formatter (rather than via Filter injection)
    keeps the behaviour robust for records emitted by third-party loggers.
    """

    def format(self, record):
        if not hasattr(record, 'task_id'):
            # Not supplied via `extra`: pull from the ContextVar, or mark
            # as N/A when we are outside any task context.
            ctx_id = task_id_var.get()
            record.task_id = ctx_id if ctx_id else 'N/A'
        elif record.task_id is None:
            # Attribute exists but is None: normalise to the placeholder.
            record.task_id = 'N/A'
        return super().format(record)
32
+
33
+
34
def setup_logging(service_name: str, level=logging.INFO):
    """Centrally configure logging; each service gets its own rotating file.

    Args:
        service_name: Service name (e.g. 'app' or 'worker'); used as both
            the logger name and the log file name.
        level: Logging level for the configured loggers.

    Returns:
        The configured logger for ``service_name``.
    """
    # Make sure the log directory exists.
    os.makedirs(BASE_LOG_DIR, exist_ok=True)

    # 1. Resolve the log file path.
    log_file_path = os.path.join(BASE_LOG_DIR, f'{service_name}.log')

    # 2. Service logger.
    logger = logging.getLogger(service_name)
    logger.setLevel(level)

    # 3. Custom formatter injects task_id (no Filter needed, so third-party
    #    loggers attached below get the field too).
    formatter = TaskIdFormatter(
        '%(asctime)s - [%(task_id)s] - %(name)s - %(levelname)s - %(message)s'
    )

    # 4. Console handler.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)

    # 5. Rotating file handler.
    file_handler = RotatingFileHandler(
        log_file_path,
        maxBytes=MAX_BYTES,
        backupCount=BACKUP_COUNT,
        encoding='utf-8'
    )
    file_handler.setFormatter(formatter)

    # 6. Attach handlers without duplicating them on repeated calls.
    abs_log_file_path = os.path.abspath(log_file_path)

    def attach_handlers_to(target_logger_name):
        """Attach the file/console handlers to a logger exactly once."""
        target = logging.getLogger(target_logger_name)
        target.setLevel(level)

        # Already have a rotating handler writing to this exact file?
        has_file_handler = any(
            isinstance(h, RotatingFileHandler)
            and os.path.abspath(h.baseFilename) == abs_log_file_path
            for h in target.handlers
        )
        if not has_file_handler:
            target.addHandler(file_handler)

        # BUGFIX: FileHandler subclasses StreamHandler, so the previous
        # `isinstance(h, logging.StreamHandler)` check matched the file
        # handler added just above and the console handler was never
        # attached.  Require a non-file StreamHandler instead.
        has_console = any(
            isinstance(h, logging.StreamHandler)
            and not isinstance(h, logging.FileHandler)
            for h in target.handlers
        )
        if not has_console:
            target.addHandler(console_handler)

    # Shared 'services' logger used by worker modules.
    attach_handlers_to('services')

    # The service's own logger.
    attach_handlers_to(service_name)

    return logger
102
+
103
+
104
+ # 辅助函数,用于简化调用
105
def get_app_logger():
    """Return the configured logger for the 'app' service."""
    return setup_logging(service_name='app')


def get_process_worker_logger():
    """Return the configured logger for the 'process_worker' service."""
    return setup_logging(service_name='process_worker')


def get_upload_worker_logger():
    """Return the configured logger for the 'upload_worker' service."""
    return setup_logging(service_name='upload_worker')
115
+
116
+
117
class RequestLogger:
    """HTTP request logging helpers (for Flask)."""

    @staticmethod
    def log_request(request, response, duration: float = None):
        """Log one HTTP request/response pair.

        Args:
            request: Flask request object.
            response: Flask response object.
            duration: Request processing time in seconds (optional).
        """
        logger = get_app_logger()

        # Assemble the message pieces.
        parts = [
            f"{request.method} {request.path}",
            f"status={response.status_code}",
        ]
        if duration is not None:
            parts.append(f"duration={duration:.3f}s")
        if request.query_string:
            parts.append(f"query={request.query_string.decode('utf-8')}")
        # Prefer the proxy-forwarded client address when present.
        client_ip = request.headers.get('X-Forwarded-For', request.remote_addr)
        parts.append(f"ip={client_ip}")

        message = " | ".join(parts)

        # Severity follows the response status code.
        status = response.status_code
        if status >= 500:
            logger.error(message)
        elif status >= 400:
            logger.warning(message)
        else:
            logger.info(message)

    @staticmethod
    def log_error(request, error: Exception):
        """Log an unhandled request error with full traceback.

        Args:
            request: Flask request object.
            error: The exception that was raised.
        """
        logger = get_app_logger()
        logger.exception(
            f"请求错误 | {request.method} {request.path} | "
            f"error={type(error).__name__}: {str(error)}"
        )
services/merger_service.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import shutil
4
+ from typing import List, Dict, Any
5
+ from merger import audio_auto_merge, AudioMerge, AudioParam
6
+
7
+ logger = logging.getLogger("services.merger")
8
+
9
class MergerService:
    """Thin service wrapper around the low-level merger module."""

    def __init__(self, config: Dict[str, Any] = None):
        # No configuration is currently needed; the parameter is kept for
        # interface symmetry with the other services.
        pass

    def merge_video(self,
                    video_path: str,
                    bgm_path: str,
                    segments: List[Dict[str, Any]],
                    output_path: str) -> str:
        """Merge TTS segments, BGM and the original video.

        Args:
            video_path: Path of the source video.
            bgm_path: Path of the background-music track.
            segments: Dicts with keys 'start_time', 'end_time', 'path',
                'original_duration', 'gen_duration', 'index'.
            output_path: Destination path for the merged video.

        Returns:
            Path of the merged output file.

        Raises:
            ValueError: If ``segments`` is empty.
            Exception: Propagated from the underlying merger on failure.
        """
        try:
            logger.info(f"Preparing to merge. Video: {video_path}, BGM: {bgm_path}")

            # 1. Convert dictionary segments into AudioParam objects.
            audio_params = [
                AudioParam(
                    start_secs=float(seg['start_time']),
                    end_secs=float(seg['end_time']),
                    clone_audio_path=seg['path'],
                    original_audio_length=float(seg['original_duration']),
                    clone_audio_length=float(seg['gen_duration']),
                    audio_sort_num=seg['index'],
                )
                for seg in segments
            ]

            if not audio_params:
                raise ValueError("No valid audio segments to merge.")

            # 2. Build the merge configuration.
            merge_config = AudioMerge(
                output_path=output_path,
                bgm_path=bgm_path,
                input_path=video_path,
                input_type="video",
                speed_strategy="max",  # Default strategy
                audio_params=audio_params,
            )

            # 3. Delegate to the existing merger implementation.
            result = audio_auto_merge(merge_config, task_logger=logger)
            return result['output_file']

        except Exception as e:
            logger.error(f"Merge failed: {e}")
            # Bare re-raise keeps the original traceback intact
            # (`raise e` would rebind it at this frame).
            raise
services/queue_manager.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import time
4
+ import uuid
5
+ from typing import Dict, Optional, Any
6
+
7
+ import redis
8
+ from redis import ConnectionPool
9
+
10
+ logger = logging.getLogger("services")
11
+
12
+
13
+ def _calculate_score(priority: int) -> float:
14
+ """
15
+ 计算任务得分 (优先级)
16
+ """
17
+ timestamp = time.time()
18
+ score = (6 - priority) * timestamp
19
+ return score
20
+
21
+
22
class QueueManager:
    """Redis queue manager.

    The processing queue is a ZSET (task_id -> priority score) paired with a
    HASH (task_id -> JSON payload); the upload queue is a plain LIST.
    """

    def __init__(self, redis_config: Dict[str, Any]):
        """Initialise the manager and verify Redis connectivity.

        Args:
            redis_config: Redis settings: host, port, db, queue_key,
                upload_queue_key, optional password / max_connections.

        Raises:
            Exception: If the Redis server cannot be reached.
        """
        host = redis_config['host']
        port = redis_config['port']
        db = redis_config['db']
        password = redis_config.get('password')
        max_connections = redis_config.get('max_connections', 10)

        # 1. Share one connection pool across all commands.
        self.pool = ConnectionPool(
            host=host,
            port=port,
            db=db,
            password=password,
            max_connections=max_connections,
            decode_responses=True
        )
        self.redis_client = redis.Redis(connection_pool=self.pool)

        self.process_queue_key = redis_config['queue_key']
        # Companion HASH holding the full JSON payload per task id.
        self.process_hash_queue_key = redis_config['queue_key'] + '_hash'
        self.upload_queue_key = redis_config['upload_queue_key']

        logger.info(
            f"QueueManager initialized with Connection Pool (Max={max_connections}): "
            f"Process Queue={self.process_queue_key}, Upload Queue={self.upload_queue_key}")

        try:
            self.redis_client.ping()
            logger.info("Redis connection successful.")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            raise

    def add_task(self, task_data: Dict[str, Any], priority: int = 3) -> str:
        """Add a task to the processing queue.

        Args:
            task_data: Task payload (hook_url, text, etc.).
            priority: Priority 1-5 (1 is popped first under the current
                scoring formula).

        Returns:
            The generated task id (UUID4 string).
        """
        task_id = str(uuid.uuid4())

        task = {
            'task_id': task_id,
            'priority': priority,
            'created_at': time.time(),
            'data': task_data
        }

        score = _calculate_score(priority)

        # One pipelined round trip keeps ZSET and HASH in step
        # (atomic + single network hop).
        pipe = self.redis_client.pipeline()
        # 1. ZSET: priority ordering / popping.
        pipe.zadd(self.process_queue_key, {task_id: score})
        # 2. HASH: uuid -> full JSON (fast read/update/delete).
        pipe.hset(self.process_hash_queue_key, task_id, json.dumps(task))
        pipe.execute()

        logger.info(f"Task {task_id} added to process queue with priority {priority}, score {score}")

        return task_id

    def get_process_task(self) -> Optional[Dict[str, Any]]:
        """Pop the highest-scoring task from the processing queue (ZSET).

        Returns:
            The task dict, or None when the queue is empty or the popped id
            no longer has a payload (e.g. a concurrently cancelled task).
        """
        # ZPOPMAX returns the highest-scoring member.
        result = self.redis_client.zpopmax(self.process_queue_key, 1)
        if not result:
            return None

        task_id, _ = result[0]

        pipe = self.redis_client.pipeline()
        pipe.hget(self.process_hash_queue_key, task_id)
        pipe.hdel(self.process_hash_queue_key, task_id)
        task_json, _ = pipe.execute()

        # BUGFIX: the payload can be gone if the task was cancelled between
        # ZPOPMAX and HGET; json.loads(None) would crash the worker loop.
        if task_json is None:
            logger.warning(f"Task {task_id} popped from ZSET but payload missing; skipping.")
            return None

        task = json.loads(task_json)

        logger.info(f"Task {task.get('task_id', 'Unknown')} retrieved from process queue.")
        return task

    def push_upload_task(self, task_result: Dict[str, Any]):
        """Push a processing result onto the upload queue (LIST).

        Args:
            task_result: Result payload (task_id, output paths, hook_url, ...).
        """
        self.redis_client.lpush(self.upload_queue_key, json.dumps(task_result))
        logger.info(f"Task {task_result['task_id']} pushed to upload queue.")

    def get_upload_task(self, timeout: int = 5) -> Optional[Dict[str, Any]]:
        """Blocking-pop a task from the upload queue (LIST).

        Args:
            timeout: Maximum seconds to block.

        Returns:
            The task dict, or None on timeout.
        """
        # BRPOP blocks up to `timeout` and returns (key, value).
        result = self.redis_client.brpop(self.upload_queue_key, timeout)
        if not result:
            return None

        task = json.loads(result[1])

        logger.debug(f"Task {task.get('task_id', 'Unknown')} retrieved from upload queue.")
        return task

    def get_process_queue_stats(self) -> Dict[str, Any]:
        """Return queued-task counts for both queues."""
        queued_count = self.redis_client.zcard(self.process_queue_key)
        upload_count = self.redis_client.llen(self.upload_queue_key)

        return {
            'process_queued': queued_count,
            'upload_queued': upload_count,
            'timestamp': time.time()
        }

    def delete_process_task(self, task_id: str) -> bool:
        """Atomically remove a queued task by id (cancellation path).

        Returns:
            True if anything was removed, False if the task was not found.
        """
        pipe = self.redis_client.pipeline()

        # 1. Check the HASH for the payload (O(1)).
        pipe.hget(self.process_hash_queue_key, task_id)

        # 2. Remove from the ZSET unconditionally so stale entries are
        #    cleaned even when the HASH entry is already gone.
        pipe.zrem(self.process_queue_key, task_id)

        json_str, zrem_count = pipe.execute()

        if json_str is not None:
            # Payload still present: the task had not been consumed yet.
            self.redis_client.hdel(self.process_hash_queue_key, task_id)
            logger.warning(f"Task {task_id} successfully removed from queue (cancelled).")
            return True

        if zrem_count > 0:
            # Only a stale ZSET entry existed; cleaned up.
            logger.info(f"Task {task_id} only existed in ZSET (stale), cleaned up.")
            return True

        logger.info(f"Task {task_id} not found in queue.")
        return False
services/r2_uploader.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Dict, Any, Optional
5
+
6
+ import boto3
7
+ from botocore.config import Config
8
+ from botocore.exceptions import ClientError, NoCredentialsError
9
+
10
+ logger = logging.getLogger("services")
11
+
12
+
13
+ def _get_content_type(file_path: str) -> Optional[str]:
14
+ """
15
+ 根据文件扩展名获取内容类型
16
+
17
+ Args:
18
+ file_path: 文件路径
19
+
20
+ Returns:
21
+ Optional[str]: 内容类型
22
+ """
23
+ extension = Path(file_path).suffix.lower()
24
+
25
+ content_types = {
26
+ '.wav': 'audio/wav',
27
+ '.mp3': 'audio/mpeg',
28
+ '.mp4': 'video/mp4',
29
+ '.ogg': 'audio/ogg',
30
+ '.flac': 'audio/flac',
31
+ '.aac': 'audio/aac',
32
+ '.m4a': 'audio/mp4',
33
+ '.srt': 'text/plain; charset=utf-8',
34
+ '.txt': 'text/plain; charset=utf-8',
35
+ '.json': 'application/json; charset=utf-8',
36
+ '.zip': 'application/zip',
37
+ '.png': 'image/png',
38
+ '.jpg': 'image/jpeg',
39
+ '.jpeg': 'image/jpeg',
40
+ '.gif': 'image/gif',
41
+ }
42
+
43
+ return content_types.get(extension, 'application/octet-stream') # 默认二进制
44
+
45
+
46
class R2Uploader:
    """Cloudflare R2 uploader (S3-compatible API via boto3)."""

    def __init__(self, r2_config: Dict[str, Any]):
        """Initialise the client and verify bucket access.

        Args:
            r2_config: Must contain 'bucket_name', 'endpoint_url',
                'access_key_id', 'secret_access_key'; 'public_url' optional.

        Raises:
            ValueError: Missing config keys or invalid credentials.
            ClientError: Bucket access failure.
        """
        self.config = r2_config
        self.bucket_name = r2_config['bucket_name']
        self.public_url = r2_config.get('public_url', '').rstrip('/')

        # Fail fast on incomplete configuration.
        required_keys = ['bucket_name', 'endpoint_url', 'access_key_id', 'secret_access_key']
        missing = [k for k in required_keys if k not in r2_config]
        if missing:
            raise ValueError(f"Missing required R2 config keys: {missing}")

        client_config = Config(
            signature_version='s3v4',  # R2 requires SigV4
            retries={
                'max_attempts': 3,
                'mode': 'standard'
            },
            connect_timeout=10,
            read_timeout=10
        )

        try:
            self.s3_client = boto3.client(
                's3',
                endpoint_url=r2_config['endpoint_url'],
                aws_access_key_id=r2_config['access_key_id'],
                aws_secret_access_key=r2_config['secret_access_key'],
                config=client_config
            )
        except NoCredentialsError:
            raise ValueError("Invalid AWS credentials (Access Key/Secret Key)")

        self._validate_bucket_access()

        logger.info(f"R2Uploader initialized for bucket: {self.bucket_name}")

    def _validate_bucket_access(self):
        """Fail fast if the bucket is missing or the token lacks permissions."""
        try:
            # Basic existence / access check.
            self.s3_client.head_bucket(Bucket=self.bucket_name)
            logger.debug(f"Bucket '{self.bucket_name}' access confirmed")

            # Listing zero keys surfaces AccessDenied early.
            self.s3_client.list_objects_v2(Bucket=self.bucket_name, MaxKeys=0)

        except ClientError as e:
            error_code = e.response['Error']['Code']
            error_msg = e.response['Error']['Message']
            if error_code == 'AccessDenied':
                raise ClientError(
                    e.response,
                    "AccessDenied: Check API Token permissions (requires 'Object Read & Write' for bucket). "
                    f"Ensure token is bound to bucket '{self.bucket_name}'. Details: {error_msg}"
                )
            elif error_code == 'NoSuchBucket':
                raise ClientError(
                    e.response,
                    f"Bucket '{self.bucket_name}' does not exist or is not accessible."
                )
            else:
                raise

    def upload_file(
        self,
        file_path: str,
        object_key: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None
    ) -> str:
        """Upload a single file to R2.

        Args:
            file_path: Local file path.
            object_key: R2 object key (path); defaults to the file name.
            metadata: Object metadata (R2 supports basic metadata; no Tagging).

        Returns:
            The public URL when ``public_url`` is configured, otherwise the
            object key.

        Raises:
            FileNotFoundError: The local file does not exist.
            ClientError: Upload failure (with detailed context).
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        # Default to the bare file name.
        if object_key is None:
            object_key = os.path.basename(file_path)

        # Keys must be non-empty and have no leading slash.
        object_key = object_key.lstrip('/')

        try:
            # Only pass parameters R2 supports (no ACL).
            extra_args = {}

            content_type = _get_content_type(file_path)
            if content_type:
                extra_args['ContentType'] = content_type

            if metadata:
                extra_args['Metadata'] = metadata

            logger.info(f"Uploading {file_path} to R2 bucket '{self.bucket_name}' as '{object_key}'")

            self.s3_client.upload_file(
                file_path,
                self.bucket_name,
                object_key,
                ExtraArgs=extra_args
            )

            # Build the public URL when one is configured.
            file_url = f"{self.public_url}/{object_key}" if self.public_url else None

            logger.info(f"File uploaded successfully: {file_url or object_key}")

            return file_url or object_key

        except ClientError as e:
            error_code = e.response['Error']['Code']
            error_msg = e.response['Error']['Message']
            logger.error(
                f"Failed to upload {file_path} to R2: Code={error_code}, Message={error_msg}. "
                f"Check: 1) API Token has 'Object Read & Write' permission bound to bucket '{self.bucket_name}'. "
                f"2) No unsupported params (e.g., ACL, Tagging). 3) Endpoint/Region correct."
            )
            raise ClientError(e.response, f"Upload failed: {error_msg}")
        except Exception as e:
            logger.error(f"Unexpected error uploading {file_path}: {str(e)}")
            raise

    def upload_files(
        self,
        file_paths: list,
        prefix: str = '',
        metadata: Optional[Dict[str, str]] = None
    ) -> Dict[str, str]:
        """Upload multiple files to R2.

        Args:
            file_paths: Local file paths.
            prefix: R2 key prefix (leading slashes stripped).
            metadata: Object metadata applied to every file.

        Returns:
            Mapping of local path -> URL (None for failed uploads).
        """
        results = {}
        prefix = prefix.lstrip('/')

        for file_path in file_paths:
            try:
                filename = os.path.basename(file_path)
                # BUGFIX: the key previously embedded the literal string
                # "(unknown)" instead of the file name, so every prefixed
                # upload collided on the same object key.
                object_key = f"{prefix}/{filename}" if prefix else filename

                file_url = self.upload_file(file_path, object_key, metadata)
                results[file_path] = file_url

            except Exception as e:
                logger.error(f"Failed to upload {file_path}: {str(e)}")
                results[file_path] = None

        successful = sum(1 for url in results.values() if url is not None)
        logger.info(f"Batch upload completed: {successful}/{len(file_paths)} files uploaded to '{self.bucket_name}'")

        return results

    def delete_file(self, object_key: str):
        """Delete one object from R2.

        Args:
            object_key: R2 object key.

        Raises:
            ClientError: Deletion failure.
        """
        object_key = object_key.lstrip('/')
        try:
            logger.info(f"Deleting '{object_key}' from R2 bucket '{self.bucket_name}'")

            self.s3_client.delete_object(
                Bucket=self.bucket_name,
                Key=object_key
            )

            logger.info(f"File deleted successfully: {object_key}")

        except ClientError as e:
            error_code = e.response['Error']['Code']
            error_msg = e.response['Error']['Message']
            logger.error(f"Failed to delete {object_key}: Code={error_code}, Message={error_msg}")
            raise

    def delete_files(self, object_keys: list):
        """Batch-delete objects from R2.

        Args:
            object_keys: R2 object keys; blank entries are skipped.
        """
        if not object_keys:
            return

        cleaned_keys = [k.lstrip('/') for k in object_keys if k.strip()]

        try:
            delete_objects = [{'Key': key} for key in cleaned_keys]

            logger.info(f"Deleting {len(cleaned_keys)} files from R2 bucket '{self.bucket_name}'")

            response = self.s3_client.delete_objects(
                Bucket=self.bucket_name,
                Delete={'Objects': delete_objects}
            )

            deleted_count = len(response.get('Deleted', []))
            errors = response.get('Errors', [])
            if errors:
                logger.warning(f"Batch delete errors: {errors}")

            logger.info(f"Batch delete completed: {deleted_count}/{len(cleaned_keys)} files deleted")

        except ClientError as e:
            error_code = e.response['Error']['Code']
            error_msg = e.response['Error']['Message']
            logger.error(f"Failed to batch delete files: Code={error_code}, Message={error_msg}")
            raise

    def file_exists(self, object_key: str) -> bool:
        """Return True if the object exists in the bucket.

        Args:
            object_key: R2 object key.

        Raises:
            ClientError: Any error other than a 404.
        """
        object_key = object_key.lstrip('/')
        try:
            self.s3_client.head_object(
                Bucket=self.bucket_name,
                Key=object_key
            )
            return True

        except ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            else:
                error_msg = e.response['Error']['Message']
                logger.error(f"Error checking existence of {object_key}: {error_msg}")
                raise

    def get_file_url(self, object_key: str) -> str:
        """Return the public URL for an object (or the key when no
        ``public_url`` is configured).

        Args:
            object_key: R2 object key.
        """
        object_key = object_key.lstrip('/')
        return f"{self.public_url}/{object_key}" if self.public_url else object_key

    def list_files(self, prefix: str = '', max_keys: int = 1000) -> list:
        """List objects in the bucket.

        Args:
            prefix: Key prefix filter.
            max_keys: Maximum number of objects to return.

        Returns:
            List of dicts: {'key', 'size', 'last_modified', 'url'}.

        Raises:
            ClientError: Listing failure.
        """
        prefix = prefix.lstrip('/')
        try:
            response = self.s3_client.list_objects_v2(
                Bucket=self.bucket_name,
                Prefix=prefix,
                MaxKeys=max_keys
            )

            files = []
            for obj in response.get('Contents', []):
                files.append({
                    'key': obj['Key'],
                    'size': obj['Size'],
                    'last_modified': obj['LastModified'].isoformat(),
                    'url': self.get_file_url(obj['Key'])
                })

            logger.info(f"Listed {len(files)} files with prefix '{prefix}' in '{self.bucket_name}'")

            return files

        except ClientError as e:
            error_code = e.response['Error']['Code']
            error_msg = e.response['Error']['Message']
            logger.error(f"Failed to list files: Code={error_code}, Message={error_msg}")
            raise
services/tts_service.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import logging
4
+ import shutil
5
+ import subprocess
6
+ from urllib.parse import urlparse
7
+ from typing import List, Dict, Any, Optional
8
+ from AgentF5TTSChunk import AgentF5TTS
9
+
10
+ logger = logging.getLogger("services.tts")
11
+
12
def get_audio_duration(file_path: str) -> float:
    """Return the duration of an audio file in seconds, probed via ffprobe.

    Falls back to 0.0 when ffprobe is unavailable or the probe fails,
    logging the error instead of raising.
    """
    probe_cmd = [
        'ffprobe', '-v', 'error', '-show_entries', 'format=duration',
        '-of', 'default=noprint_wrappers=1:nokey=1', file_path
    ]
    try:
        raw = subprocess.check_output(probe_cmd)
        return float(raw.decode().strip())
    except Exception as e:
        logger.error(f"Failed to get duration for {file_path}: {e}")
        return 0.0
24
+
25
class TTSService:
    """F5-TTS dubbing service.

    Downloads per-character reference voices, synthesizes one audio clip
    per subtitle line, and returns segment metadata (paths and timing)
    for the downstream merger.
    """

    def __init__(self, config: Dict[str, Any]):
        """Load the F5-TTS checkpoint and prepare working directories.

        Args:
            config: full application config; only the 'tts' section is read
                (keys: voices_dir, output_dir, checkpoint_file, plus optional
                device / remove_silence / speed).

        Raises:
            Exception: re-raised when the model checkpoint fails to load.
        """
        self.config = config['tts']
        self.voices_dir = self.config['voices_dir']
        self.output_dir = self.config['output_dir']

        # Ensure working directories exist before any download/inference.
        os.makedirs(self.voices_dir, exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)

        # Load the model once at startup; tasks reuse the same agent.
        logger.info("Loading F5-TTS Model...")
        try:
            self.agent = AgentF5TTS(
                ckpt_file=self.config['checkpoint_file'],
                device=self.config.get('device', 'cuda')
            )
            logger.info("F5-TTS Model Loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load F5-TTS Model: {e}")
            raise e

    def _get_extension_from_url(self, url: str) -> str:
        """Return the file extension of the URL's path, defaulting to '.wav'."""
        parsed = urlparse(url)
        ext = os.path.splitext(parsed.path)[1]
        return ext if ext else ".wav"

    def _download_file(self, url: str, path: str):
        """Stream *url* to *path* atomically.

        The payload is written to '<path>.part' and renamed into place only
        on success, so a failed or interrupted download can never leave a
        truncated file behind that would later pass the exists() cache check
        in prepare_voices() and be used as a corrupt reference voice.

        Raises:
            requests.RequestException: on HTTP errors (incl. non-2xx status).
            OSError: on disk errors.
        """
        tmp_path = f"{path}.part"
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            os.replace(tmp_path, path)  # atomic rename on POSIX
        except Exception:
            # Never leave a partial file behind on failure.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
            raise

    def prepare_voices(self, character_voices: List[Dict[str, str]]) -> Dict[str, Dict[str, str]]:
        """
        Ensure all reference voices are available locally.

        Entries without an 'id' are skipped; failed downloads are logged
        and skipped rather than failing the whole task.

        Returns a map of character_name -> {'path': local_file_path, 'text': ref_text}.
        Each voice is additionally keyed by its string id, so content rows
        may reference the voice by either character name or id.
        """
        voice_map = {}

        for cv in character_voices:
            char_name = cv.get('character')
            voice_id = cv.get('id')
            url = cv.get('timbre_url')  # reference-audio URL (renamed from character_url)
            text = cv.get('timbre_text', "")  # transcript of the reference audio

            if not voice_id:
                continue

            # Use the id as the filename so the same voice is cached once.
            ext = self._get_extension_from_url(url) if url else ".wav"
            filename = f"{voice_id}{ext}"
            local_path = os.path.join(self.voices_dir, filename)

            # Download only when not already cached locally.
            if not os.path.exists(local_path):
                if url:
                    try:
                        logger.info(f"Downloading voice {voice_id}")
                        self._download_file(url, local_path)
                    except Exception as e:
                        logger.error(f"Failed to download voice {voice_id}: {e}")
                        continue
                else:
                    logger.warning(f"Voice {voice_id} missing locally and no URL.")
                    continue

            if os.path.exists(local_path):
                voice_data = {'path': local_path, 'text': text}
                if char_name:
                    voice_map[char_name] = voice_data
                voice_map[str(voice_id)] = voice_data

        return voice_map

    def process_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a TTS generation task.

        Args:
            task: {'task_id': str, 'data': request payload carrying
                'character_voice', 'content', 'hook_url', 'video_url'}.

        Returns:
            dict: list of generated audio segments with timing metadata,
            the per-task output directory, and passthrough callback fields.

        Raises:
            ValueError: when the payload contains no content rows.
            Exception: when every segment fails to synthesize.
        """
        task_id = task['task_id']
        data = task['data']

        character_voices = data.get('character_voice', [])
        content = data.get('content', [])

        if not content:
            raise ValueError("No content provided.")

        # 1. Resolve/download reference voices.
        voice_map = self.prepare_voices(character_voices)

        # 2. Create a per-task output directory.
        task_out_dir = os.path.join(self.output_dir, task_id)
        os.makedirs(task_out_dir, exist_ok=True)

        segments_metadata = []

        # 3. Inference loop: one clip per subtitle row.
        logger.info(f"Starting inference for {len(content)} segments")

        for idx, segment in enumerate(content):
            char_name = segment.get('character')
            text = segment.get('translation')
            start_time = segment.get('start', 0.0)
            end_time = segment.get('end', 0.0)

            if not text:
                continue

            # Original on-screen duration; the merger uses it for stretching.
            original_duration = max(0.0, end_time - start_time)

            voice_data = voice_map.get(char_name)

            if not voice_data:
                logger.warning(f"Segment {idx}: No voice for '{char_name}'. Skipping.")
                continue

            ref_audio_path = voice_data['path']
            ref_audio_text = voice_data['text']

            out_filename = f"{idx:04d}.wav"
            out_path = os.path.join(task_out_dir, out_filename)

            try:
                self.agent.infer(
                    ref_file=ref_audio_path,
                    ref_text=ref_audio_text,  # reference transcript guides cloning
                    gen_text=text,
                    file_wave=out_path,
                    remove_silence=self.config.get('remove_silence', True),
                    speed=self.config.get('speed', 1.0)
                )

                if os.path.exists(out_path):
                    gen_duration = get_audio_duration(out_path)

                    segments_metadata.append({
                        'index': idx,
                        'path': out_path,
                        'start_time': start_time,
                        'end_time': end_time,
                        'original_duration': original_duration,
                        'gen_duration': gen_duration
                    })

            except Exception as e:
                # One bad line must not abort the whole task.
                logger.error(f"Inference failed for segment {idx}: {e}")

        if not segments_metadata:
            raise Exception("No audio generated.")

        return {
            'task_id': task_id,
            'segments': segments_metadata,
            'task_dir': task_out_dir,
            'hook_url': data.get('hook_url'),
            'video_url': data.get('video_url'),
            'priority': task.get('priority', 3)
        }
services/uvr5_service.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import Tuple, Dict, Any
4
+ from pathlib import Path
5
+ from audio_separator.separator import Separator
6
+ from uvr5.download_models import download_model
7
+
8
+ logger = logging.getLogger("services.uvr5")
9
+
10
class UVR5Service:
    """Vocal/instrumental (BGM) separation backed by audio-separator (UVR5)."""

    def __init__(self, config: Dict[str, Any]):
        """Read the 'uvr5' config section and load the separation model.

        Args:
            config: full application config; keys used from 'uvr5':
                model_dir, output_dir, uvr5_model (preferred) or the
                legacy model_name.

        Raises:
            Exception: re-raised when the separator/model fails to load.
        """
        self.config = config.get('uvr5', {})
        self.model_dir = self.config.get('model_dir', './models/uvr5')
        self.output_dir = self.config.get('output_dir', './temp/uvr5')

        # Prefer 'uvr5_model'; fall back to the legacy 'model_name' key.
        self.model_name = self.config.get('uvr5_model') or self.config.get('model_name', 'UVR-MDX-NET-Inst_HQ_3.onnx')
        os.makedirs(self.model_dir, exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)

        self.separator = None
        self._initialize_separator()

    def _initialize_separator(self):
        """Create the Separator and load the configured ONNX model."""
        try:
            logger.info(f"Initializing UVR5 Separator with model: {self.model_name}")

            # Normalize to a '<name>.onnx' filename exactly once; config may
            # supply the name with or without the extension.
            model_filename = self.model_name if self.model_name.endswith('.onnx') else f"{self.model_name}.onnx"

            self.separator = Separator(
                log_level=logging.INFO,
                model_file_dir=self.model_dir,
                output_dir=self.output_dir,
                output_format="wav"
            )

            # Load the model upfront; audio-separator downloads it into
            # model_file_dir automatically when it is missing.
            self.separator.load_model(model_filename)
            logger.info("UVR5 Model loaded successfully.")

        except Exception as e:
            logger.error(f"Failed to initialize UVR5: {e}")
            raise e

    @staticmethod
    def _classify_outputs(output_files, output_dir):
        """Map separator output names to (vocals_path, instrumental_path).

        Checks "Vocals" BEFORE any "Inst" substring: audio-separator embeds
        the model name in its output files (e.g.
        "song_(Vocals)_UVR-MDX-NET-Inst_HQ_3.wav"), so matching "Inst" first
        would misclassify the vocal stem as instrumental.
        os.path.join leaves already-absolute entries untouched.
        """
        vocals_path = None
        instrumental_path = None
        for filename in output_files:
            full_path = os.path.join(output_dir, filename)
            if "Vocals" in filename:
                vocals_path = full_path
            elif "Instrumental" in filename or "Inst" in filename:
                instrumental_path = full_path
        return vocals_path, instrumental_path

    def process_audio(self, input_path: str, task_id: str) -> Tuple[str, str]:
        """
        Separate audio into Vocals and Instrumental (BGM).

        Args:
            input_path: local path of the audio file to split.
            task_id: identifier used only for log correlation.

        Returns: (vocals_path, instrumental_path); vocals_path may be None
        when the vocal stem could not be identified by name.

        Raises:
            Exception: when no instrumental track can be identified.
        """
        if not self.separator:
            self._initialize_separator()

        try:
            logger.info(f"Starting UVR5 separation on {input_path}")

            # Run separation; returns output filenames (or paths).
            output_files = self.separator.separate(input_path)

            vocals_path, instrumental_path = self._classify_outputs(
                output_files, self.output_dir
            )

            if not instrumental_path:
                raise Exception("Could not identify Instrumental track from UVR5 output.")

            logger.info(f"UVR5 Complete. BGM: {instrumental_path}")
            return vocals_path, instrumental_path

        except Exception as e:
            logger.error(f"UVR5 Processing failed: {e}")
            raise e
uv.lock CHANGED
The diff for this file is too large to render. See raw diff
 
uvr5/download_models.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ UVR 模型下载脚本
4
+ 支持下载多个 UVR 模型到指定目录
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import argparse
10
+ import requests
11
+ from pathlib import Path
12
+ try:
13
+ from tqdm import tqdm
14
+ HAS_TQDM = True
15
+ except ImportError:
16
+ HAS_TQDM = False
17
+ print("提示: 安装 tqdm 可以显示下载进度: pip install tqdm")
18
+
19
+ try:
20
+ from audio_separator.separator import Separator
21
+ HAS_AUDIO_SEPARATOR = True
22
+ except ImportError:
23
+ HAS_AUDIO_SEPARATOR = False
24
+ print("警告: audio-separator 未安装,将使用手动下载模式")
25
+
26
+ # 默认配置
27
+ DEFAULT_MODEL_DIR = './models/uvr5'
28
+ DEFAULT_MODEL_NAME = 'UVR-MDX-NET-Inst_HQ_3'
29
+
30
+ # 可用的模型列表
31
+ AVAILABLE_MODELS = {
32
+ 'UVR-MDX-NET-Inst_HQ_4': {
33
+ 'url': 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR-MDX-NET-Inst_HQ_4.onnx',
34
+ 'size': '200 MB',
35
+ 'description': '高质量乐器分离模型(推荐)'
36
+ },
37
+ 'UVR-MDX-NET-Inst_HQ_3': {
38
+ 'url': 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR-MDX-NET-Inst_HQ_3.onnx',
39
+ 'size': '200 MB',
40
+ 'description': '高质量乐器分离模型 v3'
41
+ },
42
+ 'UVR_MDXNET_KARA_2': {
43
+ 'url': 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR_MDXNET_KARA_2.onnx',
44
+ 'size': '200 MB',
45
+ 'description': 'Karaoke 人声分离模型'
46
+ },
47
+ 'Kim_Vocal_2': {
48
+ 'url': 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/Kim_Vocal_2.onnx',
49
+ 'size': '200 MB',
50
+ 'description': 'Kim 人声分离模型'
51
+ },
52
+ }
53
+
54
+
55
def download_file(url: str, output_path: Path, model_name: str):
    """Download *url* into *output_path*, with a tqdm progress bar when available.

    Args:
        url: source URL of the model file.
        output_path: destination path on disk.
        model_name: model name, used only in progress/log messages.

    Returns:
        bool: True on success; on any failure the partially written file is
        removed and False is returned.
    """
    try:
        print(f"\n📥 开始下载 {model_name}...")
        print(f" URL: {url}")
        print(f" 保存到: {output_path}")

        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()

        # Content-Length may be missing; 0 disables percentage display.
        total_size = int(response.headers.get('content-length', 0))
        chunks = response.iter_content(chunk_size=8192)

        if not HAS_TQDM:
            # Plain-text fallback progress display.
            with open(output_path, 'wb') as f:
                downloaded = 0
                for chunk in chunks:
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        percent = (downloaded / total_size) * 100
                        print(f"\r 下载进度: {percent:.1f}%", end='', flush=True)
            print()  # newline after the carriage-return progress line
        else:
            with open(output_path, 'wb') as f, tqdm(
                desc=model_name,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for chunk in chunks:
                    pbar.update(f.write(chunk))

        print(f"✅ {model_name} 下载完成")
        return True

    except requests.exceptions.RequestException as e:
        print(f"❌ 下载失败: {str(e)}")
        # Remove the incomplete file so it cannot be mistaken for a model.
        if output_path.exists():
            output_path.unlink()
        return False
    except Exception as e:
        print(f"❌ 发生错误: {str(e)}")
        if output_path.exists():
            output_path.unlink()
        return False
113
+
114
+
115
def download_with_audio_separator(model_name: str, model_dir: Path):
    """Fetch *model_name* via audio-separator's built-in downloader.

    Loading a model through Separator triggers an automatic download when
    the file is not already present in *model_dir*.

    Args:
        model_name: model name, with or without the '.onnx' suffix.
        model_dir: directory to store/load the model from.

    Returns:
        bool: True when the model was loaded (downloading if needed),
        False on any failure.
    """
    try:
        print(f"\n📥 使用 audio-separator 下载 {model_name}...")

        # WARNING-level logging keeps separator output quiet during download.
        sep = Separator(
            log_level=30,  # WARNING level
            model_file_dir=str(model_dir),
            output_dir=str(model_dir)
        )

        onnx_name = model_name if model_name.endswith('.onnx') else f"{model_name}.onnx"
        sep.load_model(onnx_name)

        print(f"✅ {model_name} 下载/加载完成")
        return True

    except Exception as e:
        print(f"❌ audio-separator 下载失败: {str(e)}")
        return False
143
+
144
+
145
def list_models():
    """Print a formatted catalogue of every downloadable UVR model."""
    print("\n📋 可用的 UVR 模型:")
    print("=" * 70)
    for name, meta in AVAILABLE_MODELS.items():
        print(f"\n模型名称: {name}")
        print(f" 大小: {meta['size']}")
        print(f" 说明: {meta['description']}")
    print("\n" + "=" * 70)
154
+
155
+
156
def check_model_exists(model_name: str, model_dir: Path) -> bool:
    """Report whether '<model_name>.onnx' is already present in *model_dir*.

    Prints the on-disk size when the file is found.
    """
    candidate = model_dir / f"{model_name}.onnx"
    if not candidate.exists():
        return False
    size_mb = candidate.stat().st_size / (1024 * 1024)
    print(f"✓ {model_name} 已存在 ({size_mb:.1f} MB)")
    return True
165
+
166
+
167
def download_model(model_name: str, model_dir: Path, force: bool = False):
    """
    Download one model into *model_dir*.

    Strategy: skip when the file is already present (unless *force*),
    then try audio-separator's built-in downloader, then fall back to a
    direct HTTP download for models listed in AVAILABLE_MODELS.

    Args:
        model_name: model name, with or without the '.onnx' suffix.
        model_dir: target directory.
        force: re-download even when the file already exists.

    Returns:
        bool: True on success, False otherwise.
    """
    # Normalize exactly once so a name that already carries '.onnx'
    # (e.g. the service config default) does not produce '<name>.onnx.onnx'
    # — which made the existence check below never match.
    base_name = model_name[:-len('.onnx')] if model_name.endswith('.onnx') else model_name
    model_path = model_dir / f"{base_name}.onnx"

    # Skip when the model file is already present (unless forced).
    if model_path.exists() and not force:
        print(f"\n✓ {model_name} 已存在")
        print(f" 路径: {model_path}")
        print(f" 使用 --force 强制重新下载")
        return True

    # Option 1: prefer audio-separator's built-in downloader
    # (it supports more models than the manual list).
    if HAS_AUDIO_SEPARATOR:
        if download_with_audio_separator(model_name, model_dir):
            return True
        print(" 尝试手动下载...")

    # Option 2: manual download (only models in AVAILABLE_MODELS).
    if base_name not in AVAILABLE_MODELS:
        print(f"❌ 模型 {model_name} 无法通过手动下载")
        print(f" 请安装 audio-separator 或使用以下模型之一:")
        for name in AVAILABLE_MODELS.keys():
            print(f" - {name}")
        return False

    # Fetch the model over HTTP.
    model_info = AVAILABLE_MODELS[base_name]
    return download_file(model_info['url'], model_path, model_name)
201
+
202
+
203
def download_all_models(model_dir: Path, force: bool = False):
    """Download every model in AVAILABLE_MODELS.

    Args:
        model_dir: target directory for all models.
        force: re-download models that already exist.

    Returns:
        bool: True only when every model downloaded successfully.
    """
    print(f"\n📦 准备下载所有模型到: {model_dir}")

    total_count = len(AVAILABLE_MODELS)
    success_count = sum(
        1 for model_name in AVAILABLE_MODELS.keys()
        if download_model(model_name, model_dir, force)
    )

    print(f"\n{'=' * 70}")
    print(f"下载完成: {success_count}/{total_count} 个模型成功")
    print(f"{'=' * 70}")

    return success_count == total_count
219
+
220
+
221
def main():
    """CLI entry point: parse arguments and dispatch to the download helpers.

    Returns:
        int: process exit code (0 on success, 1 on failure).
    """
    parser = argparse.ArgumentParser(
        description='UVR 模型下载工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # Built with an f-string so the help text always names the real
        # default model — it previously hard-coded UVR-MDX-NET-Inst_HQ_4
        # while DEFAULT_MODEL_NAME is UVR-MDX-NET-Inst_HQ_3.
        epilog=f"""
示例:
  # 列出所有可用模型
  python download_models.py --list

  # 下载默认模型 ({DEFAULT_MODEL_NAME})
  python download_models.py

  # 下载指定模型
  python download_models.py --model UVR_MDXNET_KARA_2

  # 下载所有模型
  python download_models.py --all

  # 指定模型目录
  python download_models.py --dir /path/to/models

  # 强制重新下载
  python download_models.py --force
"""
    )

    parser.add_argument(
        '--list', '-l',
        action='store_true',
        help='列出所有可用的模型'
    )

    parser.add_argument(
        '--model', '-m',
        type=str,
        help='要下载的模型名称'
    )

    parser.add_argument(
        '--all', '-a',
        action='store_true',
        help='下载所有可用的模型'
    )

    parser.add_argument(
        '--dir', '-d',
        type=str,
        default=DEFAULT_MODEL_DIR,
        help=f'模型保存目录 (默认: {DEFAULT_MODEL_DIR})'
    )

    parser.add_argument(
        '--force', '-f',
        action='store_true',
        help='强制重新下载已存在的模型'
    )

    args = parser.parse_args()

    # --list only prints the catalogue; nothing is downloaded.
    if args.list:
        list_models()
        return 0

    # Create the model directory up front.
    model_dir = Path(args.dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n🎵 UVR 模型下载工具")
    print(f"模型目录: {model_dir.absolute()}")

    # --all: download the full catalogue.
    if args.all:
        success = download_all_models(model_dir, args.force)
        return 0 if success else 1

    # --model: download exactly one named model.
    if args.model:
        success = download_model(args.model, model_dir, args.force)
        return 0 if success else 1

    # No selection: fall back to the configured default model.
    default_model = DEFAULT_MODEL_NAME
    print(f"\n使用默认模型: {default_model}")
    print("提示: 使用 --list 查看所有可用模型")

    success = download_model(default_model, model_dir, args.force)

    if success:
        print(f"\n✅ 模型已准备就绪")
        print(f" 模型: {default_model}")
        print(f" 路径: {model_dir / f'{default_model}.onnx'}")
        print(f"\n💡 现在可以启动服务:")
        print(f" ./start_local.sh # 本地运行")
        print(f" ./start.sh # Docker 运行")
        return 0
    else:
        return 1
319
+
320
+
321
# Script entry point: run main() and translate an interrupt or any
# unexpected error into a non-zero exit status.
if __name__ == '__main__':
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        # User aborted (Ctrl+C) mid-download.
        print("\n\n⚠️ 下载已取消")
        sys.exit(1)
    except Exception as e:
        print(f"\n❌ 发生错误: {str(e)}")
        sys.exit(1)
uvr5/processor.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ UVR Audio Processor Consumer
4
+ Consumes tasks from Redis priority queue, performs audio separation, and sends results
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import logging
10
+ import time
11
+ import signal
12
+ import sys
13
+ import requests
14
+ import tempfile
15
+ from audio_separator.separator import Separator
16
+ import config
17
+ from redis_queue import create_redis_client, RedisPriorityQueue
18
+
19
# Configure logging
# Root-logger setup for the consumer process: INFO level with
# timestamp / logger-name / level prefixes on every line.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
25
+
26
class UVRProcessor:
    """UVR audio separation processor with graceful shutdown.

    Consumes separation tasks from a Redis priority queue, splits each
    audio file into vocal/instrumental stems with audio-separator, and
    pushes a success/failure result onto the result queue at the same
    priority. SIGINT/SIGTERM set a flag that ends the consume loop.
    """

    def __init__(self):
        """Initialize UVR processor with model loaded once.

        Loads the separation model and connects to Redis eagerly so a
        misconfiguration fails fast at startup rather than on the first task.
        """
        self.separator = None
        self.redis_client = None
        self.task_queue = None
        self.result_queue = None
        self.shutdown_flag = False

        # Register signal handlers for graceful shutdown.
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

        self._load_model()
        self._connect_redis()

    def _signal_handler(self, signum, frame):
        """Handle shutdown signals gracefully.

        Only flips shutdown_flag; the consume loop in start() observes it
        between tasks, so an in-flight task is allowed to finish.
        """
        sig_name = 'SIGTERM' if signum == signal.SIGTERM else 'SIGINT'
        logger.info(f"Received {sig_name}, initiating graceful shutdown...")
        self.shutdown_flag = True

    def _check_gpu_support(self):
        """Check if GPU/CUDA is available.

        Returns:
            tuple: (use_gpu, gpu_info) where use_gpu is True only on Linux
            with a CUDA or TensorRT ONNX execution provider available, and
            gpu_info is a human-readable description for logging.
        """
        import platform

        is_linux = platform.system() == 'Linux'
        cuda_available = False
        gpu_info = "CPU only"

        try:
            import onnxruntime as ort
            providers = ort.get_available_providers()

            if 'CUDAExecutionProvider' in providers:
                cuda_available = True
                gpu_info = "CUDA available"
                logger.info("✓ CUDA Execution Provider detected")
            elif 'TensorrtExecutionProvider' in providers:
                cuda_available = True
                gpu_info = "TensorRT available"
                logger.info("✓ TensorRT Execution Provider detected")
            else:
                logger.info("GPU providers not available, using CPU")
        except Exception as e:
            # onnxruntime missing or probing failed — fall back to CPU.
            logger.warning(f"Failed to check GPU support: {e}")

        return is_linux and cuda_available, gpu_info

    def _load_model(self):
        """Load UVR model - called once on startup (with GPU optimization if available).

        Raises:
            Exception: re-raised when the separator or model fails to load.
        """
        try:
            logger.info(f"Loading UVR model: {config.MODEL_NAME}")
            logger.info(f"Model directory: {config.MODEL_FILE_DIR}")

            # Check GPU support
            use_gpu, gpu_info = self._check_gpu_support()
            logger.info(f"Hardware acceleration: {gpu_info}")

            # Configure Separator
            # Note: audio-separator automatically uses GPU if onnxruntime-gpu is installed
            # and CUDAExecutionProvider is available. No explicit parameter needed.
            separator_kwargs = {
                'log_level': logging.INFO,
                'model_file_dir': config.MODEL_FILE_DIR,
                'output_dir': config.OUTPUT_DIR
            }

            if use_gpu:
                logger.info("🚀 GPU acceleration will be used automatically (onnxruntime-gpu detected)")
            else:
                logger.info("Running on CPU mode")

            self.separator = Separator(**separator_kwargs)

            # Ensure the model name carries the .onnx extension.
            model_filename = config.MODEL_NAME
            if not model_filename.endswith('.onnx'):
                model_filename = f"{model_filename}.onnx"

            # Check whether the model file exists locally.
            model_path = os.path.join(config.MODEL_FILE_DIR, model_filename)
            if not os.path.exists(model_path):
                logger.warning(f"Model file not found: {model_path}")
                logger.info("Attempting to download model automatically...")
                # audio-separator downloads missing models automatically.

            # Load the specific model
            self.separator.load_model(model_filename)

            if use_gpu:
                logger.info("✅ UVR model loaded successfully with GPU acceleration")
            else:
                logger.info("✅ UVR model loaded successfully (CPU mode)")

        except Exception as e:
            logger.error(f"Failed to load UVR model: {str(e)}")
            logger.error(f"Please ensure model exists at: {config.MODEL_FILE_DIR}/{config.MODEL_NAME}.onnx")
            logger.error(f"You can download it using: python3 download_models.py")
            raise

    def _connect_redis(self):
        """Initialize Redis client and priority queues.

        Raises:
            Exception: re-raised when the Redis connection cannot be created.
        """
        try:
            # Create Redis client
            self.redis_client = create_redis_client(
                host=config.REDIS_HOST,
                port=config.REDIS_PORT,
                db=config.REDIS_DB,
                password=config.REDIS_PASSWORD
            )

            # Initialize priority queues
            self.task_queue = RedisPriorityQueue(self.redis_client, config.REDIS_TASK_QUEUE)
            self.result_queue = RedisPriorityQueue(self.redis_client, config.REDIS_RESULT_QUEUE)

            logger.info("Redis connections and queues initialized")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {str(e)}")
            raise

    # Download handling moved to the API layer; this method is no longer needed.
    # def _download_audio(self, audio_url, task_uuid):
    #     moved to download_audio() in app.py

    def _separate_audio(self, input_path, task_uuid):
        """Perform audio separation using UVR.

        Args:
            input_path: local path of the audio file to split.
            task_uuid: task identifier, used only for log correlation.

        Returns:
            tuple: (vocals_path, instrumental_path) — both verified to exist.

        Raises:
            Exception: when separation fails or either stem cannot be found.
        """
        try:
            logger.info(f"[{task_uuid}] Starting audio separation")

            # Perform separation
            output_files = self.separator.separate(input_path)

            logger.info(f"[{task_uuid}] Separation complete. Output files: {output_files}")

            # Find vocals and instrumental files
            vocals_path = None
            instrumental_path = None

            for file_path in output_files:
                # Make sure we end up with a usable full path:
                # audio-separator may return relative paths or bare filenames.
                if not os.path.isabs(file_path):
                    # Relative: resolve against the configured output dir.
                    full_path = os.path.join(config.OUTPUT_DIR, file_path)
                    if not os.path.exists(full_path):
                        # Fall back to the path exactly as returned.
                        full_path = file_path
                else:
                    full_path = file_path

                # Classify the stem by filename.
                filename = os.path.basename(full_path).lower()
                if 'vocals' in filename or 'voice' in filename:
                    vocals_path = full_path
                elif 'instrumental' in filename or 'inst' in filename:
                    instrumental_path = full_path

            # If not found by name, use order
            if not vocals_path and len(output_files) > 0:
                vocals_path = output_files[0]
                if not os.path.isabs(vocals_path):
                    vocals_path = os.path.join(config.OUTPUT_DIR, vocals_path)

            if not instrumental_path and len(output_files) > 1:
                instrumental_path = output_files[1]
                if not os.path.isabs(instrumental_path):
                    instrumental_path = os.path.join(config.OUTPUT_DIR, instrumental_path)

            # Verify both stem files actually exist on disk.
            if not vocals_path or not os.path.exists(vocals_path):
                raise Exception(f"Vocals file not found: {vocals_path}")
            if not instrumental_path or not os.path.exists(instrumental_path):
                raise Exception(f"Instrumental file not found: {instrumental_path}")

            logger.info(f"[{task_uuid}] Vocals: {vocals_path}")
            logger.info(f"[{task_uuid}] Instrumental: {instrumental_path}")

            return vocals_path, instrumental_path

        except Exception as e:
            logger.error(f"[{task_uuid}] Separation failed: {str(e)}")
            raise

    def _process_task(self, task_data):
        """Process a single task.

        Runs separation on the already-downloaded local file and enqueues
        a success or failure result (at the task's priority) for the
        callback worker. The input file is always deleted afterwards.

        Args:
            task_data: dict with 'task_uuid', 'audio_path', 'hook_url' and
                optional 'priority'.
        """
        task_uuid = task_data['task_uuid']
        audio_path = task_data['audio_path']  # local path is used directly now (download moved to the API layer)
        hook_url = task_data['hook_url']
        priority = task_data.get('priority', config.DEFAULT_PRIORITY)

        vocals_path = None
        instrumental_path = None

        try:
            logger.info(f"[{task_uuid}] Processing task with priority {priority}")
            logger.info(f"[{task_uuid}] Audio file: {audio_path}")

            # Verify the audio file exists before attempting separation.
            if not os.path.exists(audio_path):
                raise Exception(f"Audio file not found: {audio_path}")

            # Separate audio (the file was already downloaded by the API layer).
            vocals_path, instrumental_path = self._separate_audio(audio_path, task_uuid)

            # Send success result to result queue with same priority
            result = {
                'task_uuid': task_uuid,
                'success': True,
                'vocals_path': vocals_path,
                'instrumental_path': instrumental_path,
                'hook_url': hook_url,
                'priority': priority
            }

            self.result_queue.enqueue(result, priority=priority)
            logger.info(f"[{task_uuid}] Success result sent to result queue with priority {priority}")

        except Exception as e:
            logger.error(f"[{task_uuid}] Task processing failed: {str(e)}")

            # Send failure result with same priority
            result = {
                'task_uuid': task_uuid,
                'success': False,
                'error_message': str(e),
                'hook_url': hook_url,
                'priority': priority
            }

            self.result_queue.enqueue(result, priority=priority)
            logger.info(f"[{task_uuid}] Failure result sent to result queue with priority {priority}")

        finally:
            # Clean up input file
            if audio_path and os.path.exists(audio_path):
                try:
                    os.remove(audio_path)
                    logger.info(f"[{task_uuid}] Cleaned up input file")
                except Exception as e:
                    logger.warning(f"[{task_uuid}] Failed to cleanup input file: {e}")

    def start(self):
        """Start consuming and processing tasks.

        Blocking loop: dequeue with a 5-second timeout, process, repeat
        until shutdown_flag is set or too many consecutive Redis errors
        occur. Backoff between errors grows with the error count (capped
        at 30 seconds).
        """
        logger.info("UVR Processor started, waiting for tasks from Redis priority queue...")

        consecutive_errors = 0
        max_consecutive_errors = 10

        try:
            while not self.shutdown_flag:
                try:
                    # Blocking dequeue with 5 second timeout
                    task_data = self.task_queue.dequeue(timeout=5)

                    if task_data is None:
                        consecutive_errors = 0  # reset the error counter
                        continue  # No task available, continue polling

                    task_uuid = task_data.get('task_uuid', 'unknown')
                    priority = task_data.get('priority', config.DEFAULT_PRIORITY)
                    logger.info(f"Received task: {task_uuid} with priority {priority}")

                    try:
                        # Process the task.
                        self._process_task(task_data)
                        logger.info(f"[{task_uuid}] Task completed successfully")
                        consecutive_errors = 0  # reset the error counter

                    except Exception as e:
                        logger.error(f"Error processing task {task_uuid}: {str(e)}")
                        # The task is already off the queue; failure results
                        # were reported inside _process_task.

                except Exception as e:
                    consecutive_errors += 1
                    logger.error(f"Redis processor error ({consecutive_errors}/{max_consecutive_errors}): {e}")

                    if consecutive_errors >= max_consecutive_errors:
                        logger.critical(f"Too many consecutive errors ({consecutive_errors}), stopping processor")
                        break

                    # Back off longer the more errors pile up.
                    sleep_time = min(consecutive_errors * 2, 30)  # sleep at most 30 seconds
                    time.sleep(sleep_time)

        except KeyboardInterrupt:
            logger.info("Received keyboard interrupt")
        except Exception as e:
            if not self.shutdown_flag:
                logger.error(f"Processor error: {str(e)}")
                raise
        finally:
            logger.info("Processor shutting down...")
            self.close()

    def close(self):
        """Close connections gracefully.

        Best-effort: errors while closing are logged, never raised.
        """
        logger.info("Closing Redis connections...")

        # Close the Redis client if it was created.
        if self.redis_client:
            try:
                logger.info("Closing Redis client...")
                self.redis_client.close()
                logger.info("Redis client closed")
            except Exception as e:
                logger.error(f"Error closing Redis client: {e}")

        logger.info("Processor shutdown complete")
337
+
338
# Script entry point: construct the processor (loads the model and connects
# to Redis eagerly) and run the blocking consume loop until shutdown.
if __name__ == '__main__':
    processor = UVRProcessor()
    processor.start()