Update scripts/qwen3_vl_embedding.py

#3
Files changed (1)
  1. scripts/qwen3_vl_embedding.py +65 -54
scripts/qwen3_vl_embedding.py CHANGED
@@ -114,33 +114,59 @@ class Qwen3VLForEmbedding(Qwen3VLPreTrainedModel):
             attention_mask=attention_mask,
         )
 
+def sample_frames(frames: List[Union[str, Image.Image]], num_segments: int, max_segments: int) -> List[str]:
+    duration = len(frames)
+    frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
+    frame_id_list = frame_id_array.tolist()
+    last_frame_id = frame_id_list[-1]
+
+    # Create a list of sampled frames
+    sampled_frames = []
+    for frame_idx in frame_id_list:
+        try:
+            sampled_frames.append(frames[frame_idx])
+        except:
+            break
+    # Ensure the sampled list meets the required segment count
+    while len(sampled_frames) < num_segments:
+        sampled_frames.append(frames[last_frame_id])
+    return sampled_frames[:max_segments]
+
 # Define embedder class for processing inputs and generating embeddings
 class Qwen3VLEmbedder():
-    def __init__(self, model_name_or_path: str, max_length: int = MAX_LENGTH,
-                 instruction: Optional[str] = None, **kwargs):
+    def __init__(
+        self,
+        model_name_or_path: str,
+        max_length: int = MAX_LENGTH,
+        min_pixels: int = MIN_PIXELS,
+        max_pixels: int = MAX_PIXELS,
+        total_pixels: int = MAX_TOTAL_PIXELS,
+        fps: float = FPS,
+        num_frames: int = MAX_FRAMES,
+        max_frames: int = MAX_FRAMES,
+        default_instruction: str = "Represent the user's input.",
+        **kwargs
+    ):
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         self.max_length = max_length
-        self.instruction = instruction or "Represent the user's input."
-        # Set pixel and frame configurations
-        self.min_pixels = kwargs.pop('min_pixels', MIN_PIXELS)
-        self.max_pixels = kwargs.pop('max_pixels', MAX_PIXELS)
-        self.total_pixels = kwargs.pop('total_pixels', MAX_TOTAL_PIXELS)
-        self.fps = kwargs.pop('fps', FPS)
-        self.num_frames = kwargs.pop('num_frames', MAX_FRAMES)
-        self.max_frames = kwargs.pop('max_frames', MAX_FRAMES)
-
-        # Initialize model and processor
+        self.min_pixels = min_pixels
+        self.max_pixels = max_pixels
+        self.total_pixels = total_pixels
+        self.fps = fps
+        self.num_frames = num_frames
+        self.max_frames = max_frames
+
+        self.default_instruction = default_instruction
+
         self.model = Qwen3VLForEmbedding.from_pretrained(
             model_name_or_path, trust_remote_code=True, **kwargs
         ).to(device)
         self.processor = Qwen3VLProcessor.from_pretrained(
             model_name_or_path, padding_side='right'
         )
+        self.model.eval()
 
-        # Define padding token id
-        self.model.eval()  # Set model to evaluation mode
-
-    # Forward pass for the embedder model
     @torch.no_grad()
     def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         outputs = self.model(**inputs)
@@ -149,25 +175,6 @@ class Qwen3VLEmbedder():
             'attention_mask': inputs.get('attention_mask')
         }
 
-    # Sample frames from video files
-    def _sample_frames(self, frames: List[str], num_segments: int, max_segments: int) -> List[str]:
-        duration = len(frames)
-        frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
-        frame_id_list = frame_id_array.tolist()
-        last_frame_id = frame_id_list[-1]
-
-        # Create a list of sampled frames
-        sampled_frames = []
-        for frame_idx in frame_id_list:
-            try:
-                sampled_frames.append(frames[frame_idx])
-            except:
-                break
-        # Ensure the sampled list meets the required segment count
-        while len(sampled_frames) < num_segments:
-            sampled_frames.append(frames[last_frame_id])
-        return sampled_frames[:max_segments]
-
     # Truncate token sequence to a specified max length
     def _truncate_tokens(self, token_ids: List[int], max_length: int) -> List[int]:
         if len(token_ids) <= max_length:
@@ -189,12 +196,14 @@ class Qwen3VLEmbedder():
         return final_token_ids
 
     # Format input based on provided text, image, video, and instruction
-    def format_model_input(self, text: Optional[str] = None,
-                           image: Optional[Union[str, Image.Image]] = None,
-                           video: Optional[Union[str, List[str]]] = None,
-                           instruction: Optional[str] = None,
-                           fps: Optional[float] = None,
-                           max_frames: Optional[int] = None) -> List[Dict]:
+    def format_model_input(
+        self, text: Optional[str] = None,
+        image: Optional[Union[str, Image.Image]] = None,
+        video: Optional[Union[str, List[Union[str, Image.Image]]]] = None,
+        instruction: Optional[str] = None,
+        fps: Optional[float] = None,
+        max_frames: Optional[int] = None
+    ) -> List[Dict]:
 
         # Ensure instruction ends with punctuation
         if instruction:
@@ -205,35 +214,37 @@ class Qwen3VLEmbedder():
         # Initialize conversation with system prompts
        content = []
        conversation = [
-            {"role": "system", "content": [{"type": "text", "text": instruction or self.instruction}]},
+            {"role": "system", "content": [{"type": "text", "text": instruction or self.default_instruction}]},
            {"role": "user", "content": content}
        ]
 
        # Add text, image, or video content to conversation
        if not text and not image and not video:
-            content.append({'type': 'text', 'text': ""})
+            content.append({'type': 'text', 'text': "NULL"})
            return conversation
 
        if video:
            video_content = None
+            video_kwargs = {'total_pixels': self.total_pixels}
            if isinstance(video, list):
                video_content = video
                if self.num_frames is not None or self.max_frames is not None:
-                    video_content = self._sample_frames(video_content, self.num_frames, self.max_frames)
-                video_content = ['file://' + ele for ele in video_content]
+                    video_content = sample_frames(video_content, self.num_frames, self.max_frames)
+                video_content = [
+                    ('file://' + ele if isinstance(ele, str) else ele)
+                    for ele in video_content
+                ]
            elif isinstance(video, str):
-                video_content = video if video.startswith(('http', 'oss')) else 'file://' + video
+                video_content = video if video.startswith(('http://', 'https://')) else 'file://' + video
+                video_kwargs = {'fps': fps or self.fps, 'max_frames': max_frames or self.max_frames}
            else:
-                video_content = video
+                raise TypeError(f"Unrecognized video type: {type(video)}")
 
            # Add video input details to content
            if video_content:
                content.append({
                    'type': 'video', 'video': video_content,
-                    'total_pixels': self.total_pixels,
-                    'max_frames': max_frames or self.max_frames,
-                    'fps': fps or self.fps,
-                    'sample_fps': fps or self.fps,
+                    **video_kwargs
                })
 
        if image:
@@ -243,7 +254,7 @@ class Qwen3VLEmbedder():
            elif isinstance(image, str):
                image_content = image if image.startswith(('http', 'oss')) else 'file://' + image
            else:
-                image_content = image
+                raise TypeError(f"Unrecognized image type: {type(image)}")
 
            # Add image input details to content
            if image_content:
@@ -270,7 +281,7 @@ class Qwen3VLEmbedder():
                return_video_metadata=True, return_video_kwargs=True
            )
        except Exception as e:
-            logger.warning(f"Error in processing vision info: {e}")
+            logger.error(f"Error in processing vision info: {e}")
            images = None
            video_inputs = None
            video_kwargs = {'do_sample_frames': False}
@@ -323,4 +334,4 @@ class Qwen3VLEmbedder():
        if normalize:
            embeddings = F.normalize(embeddings, p=2, dim=-1)
 
-        return embeddings
+        return embeddings
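
With this change the frame sampler is a module-level `sample_frames` helper instead of a private method, so it can be exercised on its own. A minimal sketch of its index selection, assuming a hypothetical list of frame paths (the paths are placeholders; only the list length matters):

```python
import numpy as np  # sample_frames relies on np.linspace, already imported in the script

# Hypothetical frame list: ten extracted frame paths.
frames = [f"/tmp/frames/{i:04d}.jpg" for i in range(10)]

# np.linspace(0, 9, 4, dtype=int) yields [0, 3, 6, 9], so the helper picks
# evenly spaced frames, pads with the last frame if it comes up short, and
# truncates the result to max_segments.
picked = sample_frames(frames, num_segments=4, max_segments=4)
# picked == ['/tmp/frames/0000.jpg', '/tmp/frames/0003.jpg',
#            '/tmp/frames/0006.jpg', '/tmp/frames/0009.jpg']
```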
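The constructor now takes its pixel, frame, and instruction settings as explicit keyword arguments, and `**kwargs` is forwarded only to `from_pretrained`. A minimal usage sketch under stated assumptions: the checkpoint path and the keyword values below are placeholders rather than values from this PR, and only methods visible in the diff are called.

```python
# Placeholder checkpoint path; substitute a real Qwen3-VL embedding checkpoint.
embedder = Qwen3VLEmbedder(
    "path/to/qwen3-vl-embedding-checkpoint",
    max_length=4096,          # placeholder values; defaults come from the
    fps=2.0,                  # module-level constants (MAX_LENGTH, FPS, ...)
    max_frames=16,
    default_instruction="Represent the user's input.",
)

# Build the chat-style conversation for a text + image query. Per the diff,
# this returns the [system, user] message list consumed later by the processor.
conversation = embedder.format_model_input(
    text="a photo of a golden retriever",
    image="/path/to/query.jpg",
    instruction="Represent this query for image retrieval.",
)

# A string video path gets per-call fps/max_frames kwargs, while a list of
# frame paths is subsampled via the module-level sample_frames helper.
video_conversation = embedder.format_model_input(
    video="/path/to/clip.mp4",
    fps=1.0,
    max_frames=8,
)
```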