Instructions to use nvidia/Cosmos-Embed1-224p with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Cosmos
How to use nvidia/Cosmos-Embed1-224p with Cosmos:
# No code snippets are available yet for this library.
# To use this model, check the repository files and the library's documentation.
# Want to help? PRs adding snippets are welcome at:
# https://github.com/huggingface/huggingface.js
- NeMo
How to use nvidia/Cosmos-Embed1-224p with NeMo:
# The model tag does not correspond to a valid NeMo domain.
- Notebooks
- Google Colab
- Kaggle
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Cosmos-Embed1 text+video embedder.""" | |
import logging
import math
from copy import deepcopy

import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from transformers import AutoModel, PreTrainedModel

from .configuration_embed1 import CosmosEmbed1Config
from .modeling_outputs import TextEmbedderOutput, TextVideoEmbedderOutput, VideoEmbedderOutput
from .modeling_qformer import BertLMHeadModel, load_qformer
from .modeling_utils import EncodingFactory, rank0_first
from .modeling_vit import EvaViTG
class CosmosEmbed1(PreTrainedModel):
    """Joint text/video embedder built from an `EvaViTG` visual encoder and a Q-Former.

    Video frames are encoded independently by the ViT, given a temporal encoding, and
    summarized by the Q-Former's learned query tokens; text is encoded by the Q-Former's
    text branch. Both sides are linearly projected to ``config.embed_dim`` and
    L2-normalized.
    """

    config_class = CosmosEmbed1Config

    def __init__(self, config: CosmosEmbed1Config) -> None:
        """Cosmos-Embed1 video embedder constructor.

        Args:
            config (CosmosEmbed1Config): Model configuration.
        """
        super().__init__(config)
        self.embed_dim = config.embed_dim
        self.num_query_tokens = config.num_query_tokens
        self.num_video_frames = config.num_video_frames
        self.temporal_encoding_type = config.temporal_encoding_type
        self.resolution = config.resolution
        self.vocab_size = config.vocab_size
        self.transformer_engine = config.transformer_engine
        self.use_fp8 = config.use_fp8

        # visual encoder initialization.
        # Per-channel RGB normalization constants (the commonly used ImageNet mean/std),
        # shaped (1, 1, 3, 1, 1) to broadcast over (batch, time, channel, height, width).
        # Non-persistent: constants, not saved in checkpoints.
        self.register_buffer(
            "normalization_mean",
            torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "normalization_std",
            torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.visual_encoder = EvaViTG(
            img_size=self.resolution,
            transformer_engine=self.transformer_engine,
            use_fp8=self.use_fp8,
        )
        self.ln_vision = nn.LayerNorm(self.visual_encoder.embed_dim)

        # qformer initialization
        self.qformer, self.query_tokens = self._init_qformer(
            num_query_tokens=self.num_query_tokens,
            encoder_width=self.visual_encoder.embed_dim,
            vocab_size=self.vocab_size,
        )
        # Initialize every Q-Former parameter whose name contains "_query" from the
        # parameter with the same name minus "_query".
        qformer_state = self.qformer.state_dict()
        with torch.no_grad():
            for name, param in self.qformer.named_parameters():
                if "_query" in name:
                    param.copy_(qformer_state[name.replace("_query", "")])

        # temporal encoding
        self.temporal_encoding = EncodingFactory(
            self.temporal_encoding_type,
            embed_dim=self.visual_encoder.embed_dim,
            max_len=self.num_video_frames,
        )

        # output projections
        self.vision_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
        self.text_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
        # 2-way head; presumably video-text matching (ITM). Not used within this class.
        self.itm_proj = nn.Linear(self.qformer.config.hidden_size, 2)

        # initialize logit scale/bias like SigLIP (as per Table 4 in https://arxiv.org/pdf/2303.15343)
        self.logit_scale = nn.Parameter(torch.tensor(math.log(10.0)))
        self.logit_bias = nn.Parameter(torch.tensor(-10.0))

    def hidden_dim(self) -> int:
        """Return the feature width of the visual encoder (``visual_encoder.embed_dim``)."""
        return self.visual_encoder.embed_dim

    def no_weight_decay(self) -> set:
        """Return the names of parameters to exclude from weight decay."""
        return {"logit_scale", "logit_bias"}

    def forward(
        self,
        videos: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
    ) -> TextVideoEmbedderOutput:
        """Forward function for `CosmosEmbed1`.

        Args:
            videos (`torch.Tensor` of shape `(batch_size, num_frames, RGB, height, width)`):
                batched videos with fixed number of RGB frames.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                Indices can be obtained by using [`AutoTokenizer`, `CosmosEmbed1Tokenizer`].
            attention_mask: (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Mask to avoid performing attention on padding token indices.
                Mask values selected in `[0, 1]`.
                - 1 for tokens that are **not masked**.
                - 0 for tokens that are **masked**.

        Returns:
            TextVideoEmbedderOutput: the fields of `get_video_embeddings` and
            `get_text_embeddings` merged into a single output.
        """
        video_output = self.get_video_embeddings(videos)
        text_output = self.get_text_embeddings(input_ids, attention_mask)
        return TextVideoEmbedderOutput(**video_output, **text_output)

    def get_video_embeddings(self, videos: torch.Tensor) -> VideoEmbedderOutput:
        """Embed a batch of videos.

        Args:
            videos (`torch.Tensor` of shape `(batch_size, num_frames, 3, height, width)`):
                RGB frames. They are normalized here with the mean/std buffers, so they
                presumably must already be scaled to [0, 1] (TODO confirm with the preprocessor).

        Returns:
            VideoEmbedderOutput with:
                - visual_proj: L2-normalized video embedding, `(batch_size, embed_dim)`.
                - visual_embs: per-patch ViT features (after LayerNorm and temporal encoding),
                  `(batch_size, num_frames, height // patch_size, width // patch_size, vit_dim)`.
                - visual_query_output: raw Q-Former output for the query tokens.
                - visual_cls_tokens: mean of the Q-Former query outputs over the query axis.
                - frame_cls_tokens: first ViT token of every frame, `(batch_size, num_frames, 1, vit_dim)`.
        """
        videos = (videos - self.normalization_mean) / self.normalization_std
        batch_size, num_frames, _, H, W = videos.shape
        frame_batch = rearrange(videos, "b t c h w -> (b t) c h w")

        # process video frames through ViT
        visual_embs = self.visual_encoder(frame_batch)
        visual_embs = self.ln_vision(visual_embs)
        visual_embs = rearrange(visual_embs, "(b t) k d -> b t k d", b=batch_size, t=num_frames)

        # add temporal encoding
        visual_embs = self.temporal_encoding(visual_embs)

        # Q-Former cross-attention: the queries attend over all (frame, token) positions,
        # none of which are masked out.
        encoder_hidden_states = rearrange(visual_embs, "b t k d -> b (t k) d")
        encoder_attention_mask = torch.ones(
            encoder_hidden_states.shape[:-1], dtype=torch.long, device=videos.device
        )
        query_tokens = self.query_tokens.expand(encoder_hidden_states.size(0), -1, -1)
        visual_query_output = self.qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )
        visual_cls_tokens = visual_query_output.last_hidden_state.mean(dim=1, keepdim=False)
        visual_proj = F.normalize(self.vision_proj(visual_cls_tokens), dim=-1)

        # reshape visual embs to (B,T,H,W,D), to conform with expected output.
        # separate out the frame-level cls tokens if necessary.
        frame_cls_tokens, visual_embs = visual_embs[:, :, 0:1], visual_embs[:, :, 1:]
        h = H // self.visual_encoder.patch_size
        w = W // self.visual_encoder.patch_size
        visual_embs = rearrange(visual_embs, "b t (h w) d -> b t h w d", h=h, w=w)

        return VideoEmbedderOutput(
            visual_proj=visual_proj,
            visual_embs=visual_embs,
            visual_query_output=visual_query_output,
            visual_cls_tokens=visual_cls_tokens,
            frame_cls_tokens=frame_cls_tokens,
        )

    def get_text_embeddings(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
    ) -> TextEmbedderOutput:
        """Embed a batch of tokenized captions with the Q-Former text branch.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): token ids.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`): 1 for
                tokens to attend to, 0 for padding. Cast to the query-token dtype before use.

        Returns:
            TextEmbedderOutput with:
                - text_proj: L2-normalized projection of the first token's final hidden
                  state, `(batch_size, embed_dim)`.
                - text_embs: final hidden states for all tokens.
                - text_query_output: raw Q-Former output.
        """
        text_query_output = self.qformer.bert(
            input_ids=input_ids,
            attention_mask=attention_mask.to(dtype=self.query_tokens.dtype),
            return_dict=True,
        )
        first_token_state = text_query_output.last_hidden_state[:, 0, :]
        text_proj = F.normalize(self.text_proj(first_token_state), dim=-1)
        return TextEmbedderOutput(
            text_proj=text_proj,
            text_embs=text_query_output.last_hidden_state,
            text_query_output=text_query_output,
        )

    @classmethod
    def _init_qformer(
        cls: type["CosmosEmbed1"],
        num_query_tokens: int,
        encoder_width: int,
        vocab_size: int,
        hidden_size: int = 768,
    ) -> tuple[BertLMHeadModel, nn.Parameter]:
        """Convenience function for initializing QFormer module.

        Args:
            num_query_tokens: number of learned query tokens.
            encoder_width: feature width of the visual encoder the Q-Former cross-attends to.
            vocab_size: text vocabulary size.
            hidden_size: Q-Former hidden size; also the width of the query tokens.

        Returns:
            The Q-Former and a `(1, num_query_tokens, hidden_size)` parameter of query
            tokens drawn from N(0, 0.02^2).
        """
        qformer = load_qformer(
            num_query_tokens=num_query_tokens,
            encoder_width=encoder_width,
            hidden_size=hidden_size,
            vocab_size=vocab_size,
        )
        query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, hidden_size))
        query_tokens.data.normal_(mean=0.0, std=0.02)
        return qformer, query_tokens

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "CosmosEmbed1":
        """Load a pretrained model, rebuilding it with Transformer Engine when configured.

        If the resolved config has ``transformer_engine`` enabled, the checkpoint is first
        loaded into a plain model (``transformer_engine=False``, ``use_fp8=False``). Its
        state dict is then copied non-strictly into a new model built from the original
        config. Otherwise this defers directly to ``PreTrainedModel.from_pretrained``.

        Args:
            pretrained_model_name_or_path (str): Hub model id or local directory.
            **kwargs: Forwarded to ``PreTrainedModel.from_pretrained``. A ``config`` entry,
                if present, is used instead of loading one from the path.

        Returns:
            CosmosEmbed1: the loaded model, in eval mode.
        """
        config = kwargs.get("config", None)
        if config is None:
            config = CosmosEmbed1Config.from_pretrained(pretrained_model_name_or_path)

        if not config.transformer_engine:
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        config_no_te = deepcopy(config)
        config_no_te.transformer_engine = False
        config_no_te.use_fp8 = False  # Also disable FP8 for the base model

        # Only the ``config`` entry must differ, so a shallow copy suffices; deep-copying
        # every forwarded kwarg (which may hold tensors) wastes memory and may fail.
        kwargs_no_te = {**kwargs, "config": config_no_te}

        # Load standard (non-TE) model & weights
        base_model = super().from_pretrained(pretrained_model_name_or_path, **kwargs_no_te)
        base_state_dict = base_model.state_dict()

        # Now build the TE version of the model and load the non-TE weights into it.
        model_with_te = cls(config=config)
        missing, unexpected = model_with_te.load_state_dict(base_state_dict, strict=False)
        logger = logging.getLogger(__name__)
        if missing:
            logger.warning("[TransformerEngine] Missing keys: %s", missing)
        if unexpected:
            logger.warning("[TransformerEngine] Unexpected keys: %s", unexpected)

        # ``PreTrainedModel.from_pretrained`` returns models in eval mode, but a freshly
        # constructed module is in training mode; match that contract here.
        model_with_te.eval()
        return model_with_te
# Make the transformers Auto classes aware of this model: a CosmosEmbed1Config now
# resolves to CosmosEmbed1 (e.g. in AutoModel.from_pretrained / AutoModel.from_config).
AutoModel.register(config_class=CosmosEmbed1Config, model_class=CosmosEmbed1)