import inspect
from typing import Dict, List, Optional, Union

import torch

from diffusers import AutoencoderKLWan
from diffusers.modular_pipelines import (
    ModularPipeline,
    ModularPipelineBlocks,
    PipelineState,
    SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    ConfigSpec,
    InputParam,
    OutputParam,
)
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor

logger = logging.get_logger(__name__)

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

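# Illustrative usage (a sketch, not exercised by the pipeline itself): passing a
# custom sigma schedule. Assumes a scheduler whose `set_timesteps` accepts `sigmas`,
# e.g. UniPCMultistepScheduler:
#
#   scheduler = UniPCMultistepScheduler()
#   timesteps, num_inference_steps = retrieve_timesteps(
#       scheduler, sigmas=[0.9, 0.6, 0.3, 0.0], device="cpu"
#   )
#   # num_inference_steps == len(timesteps)
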
def retrieve_latents(
    encoder_output: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "sample",
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")

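# Illustrative usage (a sketch): deterministic latents from a VAE encode. `vae` is
# assumed to be an AutoencoderKLWan instance and `frames` a [B, C, T, H, W] tensor:
#
#   posterior = vae.encode(frames)
#   latents = retrieve_latents(posterior, sample_mode="argmax")  # distribution mode, no sampling noise
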
def _initialize_kv_cache(
    components: ModularPipeline,
    kv_cache_existing: Optional[List[Dict]],
    batch_size: int,
    dtype: torch.dtype,
    device: torch.device,
    local_attn_size: int,
    frame_seq_length: int,
):
    """
    Initialize a per-GPU KV cache for the Wan model.
    Mirrors causal_inference.py:279-313
    """
    kv_cache = []

    if local_attn_size != -1:
        # Local attention: the cache only needs to hold `local_attn_size` frames.
        kv_cache_size = local_attn_size * frame_seq_length
    else:
        # Global attention: fall back to the default full-sequence cache size.
        kv_cache_size = 32760

    num_transformer_blocks = len(components.transformer.blocks)
    num_heads = components.transformer.config.num_heads
    dim = components.transformer.config.dim
    k_shape = [batch_size, kv_cache_size, num_heads, dim // num_heads]
    v_shape = [batch_size, kv_cache_size, num_heads, dim // num_heads]

    # Reuse the existing cache when its shapes match; otherwise allocate a new one.
    if (
        kv_cache_existing
        and len(kv_cache_existing) > 0
        and list(kv_cache_existing[0]["k"].shape) == k_shape
        and list(kv_cache_existing[0]["v"].shape) == v_shape
    ):
        for i in range(num_transformer_blocks):
            kv_cache_existing[i]["k"].zero_()
            kv_cache_existing[i]["v"].zero_()
            kv_cache_existing[i]["global_end_index"] = 0
            kv_cache_existing[i]["local_end_index"] = 0
        return kv_cache_existing
    else:
        # Allocate one cache entry per transformer block.
        for _ in range(num_transformer_blocks):
            kv_cache.append(
                {
                    "k": torch.zeros(k_shape, dtype=dtype, device=device).contiguous(),
                    "v": torch.zeros(v_shape, dtype=dtype, device=device).contiguous(),
                    "global_end_index": 0,
                    "local_end_index": 0,
                }
            )
        return kv_cache

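# Worked shape example (a sketch; num_heads=12 and dim=1536 are assumptions for
# illustration, not values read from this file): with local_attn_size=6 and
# frame_seq_length=1560, each block's cache holds 6 * 1560 = 9360 tokens, so the
# per-block key tensor has shape [1, 9360, 12, 128]:
#
#   kv_cache = _initialize_kv_cache(
#       components, None, batch_size=1, dtype=torch.bfloat16,
#       device=torch.device("cuda"), local_attn_size=6, frame_seq_length=1560,
#   )
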
def _initialize_crossattn_cache(
    components: ModularPipeline,
    crossattn_cache_existing: Optional[List[Dict]],
    batch_size: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """
    Initialize a per-GPU cross-attention cache for the Wan model.
    Mirrors causal_inference.py:315-338
    """
    crossattn_cache = []

    num_transformer_blocks = len(components.transformer.blocks)
    num_heads = components.transformer.config.num_heads
    dim = components.transformer.config.dim
    # The cached text context has a fixed length of 512 tokens.
    k_shape = [batch_size, 512, num_heads, dim // num_heads]
    v_shape = [batch_size, 512, num_heads, dim // num_heads]

    # Reuse the existing cache when its shapes match; otherwise allocate a new one.
    if (
        crossattn_cache_existing
        and len(crossattn_cache_existing) > 0
        and list(crossattn_cache_existing[0]["k"].shape) == k_shape
        and list(crossattn_cache_existing[0]["v"].shape) == v_shape
    ):
        for i in range(num_transformer_blocks):
            crossattn_cache_existing[i]["k"].zero_()
            crossattn_cache_existing[i]["v"].zero_()
            crossattn_cache_existing[i]["is_init"] = False
        return crossattn_cache_existing
    else:
        # Allocate one cache entry per transformer block.
        for _ in range(num_transformer_blocks):
            crossattn_cache.append(
                {
                    "k": torch.zeros(k_shape, dtype=dtype, device=device).contiguous(),
                    "v": torch.zeros(v_shape, dtype=dtype, device=device).contiguous(),
                    "is_init": False,
                }
            )
        return crossattn_cache

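# Illustrative reuse behavior (a sketch): calling the initializer again with a
# shape-matching cache zeroes it in place instead of reallocating, which avoids
# re-allocating large tensors between streaming runs:
#
#   cache = _initialize_crossattn_cache(components, None, 1, torch.bfloat16, device)
#   same_cache = _initialize_crossattn_cache(components, cache, 1, torch.bfloat16, device)
#   assert same_cache is cache  # zeroed and reused, `is_init` reset to False
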
class WanInputStep(ModularPipelineBlocks):
    model_name = "WanRT"

    @property
    def description(self) -> str:
        return (
            "Input processing step that:\n"
            "  1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            "  2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n"
            "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
            "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
            "have a final batch_size of batch_size * num_videos_per_prompt."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_videos_per_prompt", default=1),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Pre-generated text embeddings. Can be generated from text_encoder step.",
            ),
            InputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Number of prompts; the final batch size of model inputs should be batch_size * num_videos_per_prompt",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
            ),
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide video generation",
            ),
            OutputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Negative text embeddings used to guide video generation",
            ),
        ]

    def check_inputs(self, components, block_state):
        if (
            block_state.prompt_embeds is not None
            and block_state.negative_prompt_embeds is not None
        ):
            if (
                block_state.prompt_embeds.shape
                != block_state.negative_prompt_embeds.shape
            ):
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {block_state.negative_prompt_embeds.shape}."
                )

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(components, block_state)

        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype

        # Duplicate the embeddings once per requested video, then flatten the
        # copies back into the batch dimension.
        _, seq_len, _ = block_state.prompt_embeds.shape
        block_state.prompt_embeds = block_state.prompt_embeds.repeat(
            1, block_state.num_videos_per_prompt, 1
        )
        block_state.prompt_embeds = block_state.prompt_embeds.view(
            block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
        )

        if block_state.negative_prompt_embeds is not None:
            _, seq_len, _ = block_state.negative_prompt_embeds.shape
            block_state.negative_prompt_embeds = (
                block_state.negative_prompt_embeds.repeat(
                    1, block_state.num_videos_per_prompt, 1
                )
            )
            block_state.negative_prompt_embeds = (
                block_state.negative_prompt_embeds.view(
                    block_state.batch_size * block_state.num_videos_per_prompt,
                    seq_len,
                    -1,
                )
            )

        self.set_block_state(state, block_state)

        return components, state

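# Shape walkthrough for the repeat/view trick above (a sketch; the sizes are
# arbitrary illustration values): with 2 prompts, seq_len=512, dim=4096 and
# num_videos_per_prompt=3, `repeat(1, 3, 1)` gives [2, 1536, 4096] and the `view`
# reshapes it to [6, 512, 4096] -- three copies per prompt, ordered prompt-major:
#
#   embeds = torch.randn(2, 512, 4096)
#   out = embeds.repeat(1, 3, 1).view(2 * 3, 512, -1)
#   assert out.shape == (6, 512, 4096)
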
class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks):
    model_name = "WanRT"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return "Step that sets the timestep schedule used for inference"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=4),
            InputParam("timesteps"),
            InputParam("sigmas"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "timesteps",
                type_hint=torch.Tensor,
                description="The timesteps to use for inference",
            ),
            OutputParam(
                "all_timesteps",
                type_hint=torch.Tensor,
                description="The full 1000-entry timestep schedule from which the inference timesteps are drawn",
            ),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        # Note: the `timesteps`/`sigmas` inputs are accepted for API compatibility,
        # but the streaming schedule below is fixed.
        # Build a shifted sigma schedule (shift=5.0) over 1000 steps.
        shift = 5.0
        sigmas = torch.linspace(1.0, 0.0, 1001)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        timesteps = sigmas.to(components.transformer.device) * 1000.0
        # Append a terminal zero so that index 1000 maps to timestep 0.
        zero_padded_timesteps = torch.cat(
            [
                timesteps,
                torch.tensor([0], device=components.transformer.device),
            ]
        )
        # Pick `num_inference_steps` evenly spaced entries from 1000 down to 0.
        denoising_steps = torch.linspace(
            1000, 0, block_state.num_inference_steps, dtype=torch.float32
        ).to(torch.long)

        block_state.timesteps = zero_padded_timesteps[1000 - denoising_steps]
        block_state.all_timesteps = timesteps
        block_state.sigmas = sigmas

        self.set_block_state(state, block_state)

        return components, state

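# Worked example of the schedule above (a sketch): with shift=5.0 the shifted sigma
# at s is 5s / (1 + 4s). For num_inference_steps=4, denoising_steps truncates to
# [1000, 666, 333, 0], so the indices 1000 - denoising_steps = [0, 334, 667, 1000]
# select sigma(1.0)=1.0, sigma(0.666)~0.909, sigma(0.333)~0.714 and the padded
# terminal 0, i.e. timesteps of roughly [1000, 909, 714, 0].
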
class WanRTStreamingPrepareLatentsStep(ModularPipelineBlocks):
    model_name = "WanRT"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLWan),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("num_frames_per_block", 3)]

    @property
    def description(self) -> str:
        return "Step that prepares the initial latents for the streaming text-to-video generation process"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("num_blocks", type_hint=int),
            InputParam("num_frames_per_block", type_hint=int),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("init_latents", type_hint=Optional[torch.Tensor]),
            InputParam("final_latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_videos_per_prompt", type_hint=int, default=1),
            InputParam("generator"),
            InputParam(
                "dtype",
                type_hint=torch.dtype,
                description="The dtype of the model inputs",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="The latents to use for the denoising process",
            ),
            OutputParam(
                "init_latents",
                type_hint=torch.Tensor,
                description="The initial noise latents for the full video buffer",
            ),
            OutputParam(
                "final_latents",
                type_hint=torch.Tensor,
                description="Full latent buffer [B, C, total_frames, H, W]",
            ),
        ]

    @staticmethod
    def check_inputs(components, block_state):
        if (
            block_state.height is not None
            and block_state.height % components.vae_scale_factor_spatial != 0
        ) or (
            block_state.width is not None
            and block_state.width % components.vae_scale_factor_spatial != 0
        ):
            raise ValueError(
                f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
            )

    @staticmethod
    def prepare_latents(
        components,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 352,
        width: int = 640,
        num_blocks: int = 9,
        num_frames_per_block: int = 3,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        num_latent_frames = num_blocks * num_frames_per_block
        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // components.vae_scale_factor_spatial,
            int(width) // components.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # Noise is allocated directly on the transformer's device.
        latents = randn_tensor(
            shape,
            generator=generator,
            device=components.transformer.device,
            dtype=dtype,
        )
        return latents

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width
        block_state.device = components._execution_device
        block_state.num_channels_latents = components.num_channels_latents

        self.check_inputs(components, block_state)

        # Streaming generation currently runs with a single video (batch_size=1).
        block_state.init_latents = self.prepare_latents(
            components,
            1,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.num_blocks,
            components.config.num_frames_per_block,
            components.transformer.dtype,
            block_state.device,
            block_state.generator,
            block_state.init_latents,
        )
        if block_state.final_latents is None:
            block_state.final_latents = torch.zeros_like(
                block_state.init_latents, device=components.transformer.device
            )
        self.set_block_state(state, block_state)

        return components, state

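# Shape walkthrough (a sketch using this file's defaults, and assuming the Wan VAE's
# spatial scale factor of 8): height=352, width=640, num_blocks=9 and
# num_frames_per_block=3 give an init_latents buffer of shape [1, 16, 27, 44, 80]:
#
#   num_latent_frames = 9 * 3       # 27
#   latent_height     = 352 // 8    # 44
#   latent_width      = 640 // 8    # 80
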
class WanRTStreamingExtractBlockLatentsStep(ModularPipelineBlocks):
    """
    Extracts a single block of latents from the full video buffer for streaming generation.

    This block simply slices the init_latents buffer to get the current block's starting
    latents. The buffers should be created beforehand using WanRTStreamingPrepareLatentsStep.
    """

    model_name = "WanRT"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def description(self) -> str:
        return (
            "Extracts a single block from the full latent buffer for streaming generation. "
            "Slices init_latents based on block_idx to get the current block's latents."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "final_latents",
                required=True,
                type_hint=torch.Tensor,
                description="Full latent buffer [B, C, total_frames, H, W]",
            ),
            InputParam(
                "init_latents",
                required=True,
                type_hint=torch.Tensor,
                description="Full initial-noise buffer [B, C, total_frames, H, W]",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latents for the current block, if already extracted",
            ),
            InputParam(
                "block_idx",
                required=True,
                type_hint=int,
                default=0,
                description="Current block index to process",
            ),
            InputParam(
                "num_frames_per_block",
                required=True,
                type_hint=int,
                default=3,
                description="Number of frames per block",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latents for current block [B, C, num_frames_per_block, H, W]",
            ),
            OutputParam(
                "current_start_frame",
                type_hint=int,
                description="Starting frame index for current block",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)

        num_frames_per_block = block_state.num_frames_per_block
        block_idx = block_state.block_idx

        # Frame range covered by the current block.
        start_frame = block_idx * num_frames_per_block
        end_frame = start_frame + num_frames_per_block

        # Slice the initial-noise buffer to get the latents this block starts from.
        block_state.latents = block_state.init_latents[
            :, :, start_frame:end_frame, :, :
        ]
        block_state.current_start_frame = start_frame

        self.set_block_state(state, block_state)
        return components, state

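# Index walkthrough (a sketch): with num_frames_per_block=3, block_idx=2 selects
# latent frames [6:9] from the [B, C, 27, H, W] buffer, and current_start_frame=6
# is later converted into a token offset via frame_seq_length.
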
class WanRTStreamingSetupKVCache(ModularPipelineBlocks):
    """
    Initializes KV cache and cross-attention cache for streaming generation.

    This block sets up the persistent caches used across all blocks in streaming
    generation. Mirrors the cache initialization logic from causal_inference.py.
    Should be called once at the start of streaming generation.
    """

    model_name = "WanRT"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("transformer", torch.nn.Module),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec("kv_cache_num_frames", 3),
            ConfigSpec("num_frames_per_block", 3),
            ConfigSpec("frame_seq_length", 1560),
            ConfigSpec("frame_cache_len", 9),
        ]

    @property
    def description(self) -> str:
        return (
            "Initializes KV cache and cross-attention cache for streaming generation. "
            "Creates persistent caches that will be reused across all blocks."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "kv_cache",
                required=False,
                type_hint=Optional[List[Dict]],
                description="Existing KV cache. If provided and shape matches, will be zeroed instead of recreated.",
            ),
            InputParam(
                "crossattn_cache",
                required=False,
                type_hint=Optional[List[Dict]],
                description="Existing cross-attention cache. If provided and shape matches, will be zeroed.",
            ),
            InputParam(
                "local_attn_size",
                required=False,
                type_hint=int,
                default=-1,
                description="Local attention size for computing KV cache size. -1 uses the default (32760).",
            ),
            InputParam(
                "dtype",
                required=False,
                type_hint=torch.dtype,
                description="Data type for caches (defaults to bfloat16)",
            ),
            InputParam(
                "update_prompt_embeds",
                required=False,
                description="Flag to reinitialize the cross-attention cache when the prompt embeddings change.",
                default=False,
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "kv_cache",
                type_hint=List[Dict],
                description="Initialized KV cache (list of dicts per transformer block)",
            ),
            OutputParam(
                "crossattn_cache",
                type_hint=List[Dict],
                description="Initialized cross-attention cache",
            ),
            OutputParam(
                "local_attn_size",
                type_hint=int,
                description="KV-cache window size in frames (kv_cache_num_frames + num_frames_per_block)",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        batch_size = 1

        kv_cache = block_state.kv_cache
        crossattn_cache = block_state.crossattn_cache

        # The cross-attention cache only needs to be (re)built when it does not
        # exist yet or when the prompt embeddings have changed.
        if block_state.crossattn_cache is None or block_state.update_prompt_embeds:
            block_state.crossattn_cache = _initialize_crossattn_cache(
                components,
                crossattn_cache,
                batch_size,
                components.transformer.dtype,
                components.transformer.device,
            )

        # Size the KV cache to hold the cached context frames plus the current block,
        # while the attention modules themselves run unrestricted (local_attn_size=-1).
        block_state.local_attn_size = (
            components.config.kv_cache_num_frames
            + components.config.num_frames_per_block
        )
        for block in components.transformer.blocks:
            block.self_attn.local_attn_size = -1
            block.self_attn.num_frame_per_block = components.config.num_frames_per_block

        block_state.kv_cache = _initialize_kv_cache(
            components,
            kv_cache,
            batch_size,
            components.transformer.dtype,
            components.transformer.device,
            block_state.local_attn_size,
            components.config.frame_seq_length,
        )

        self.set_block_state(state, block_state)
        return components, state

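# Cache-size arithmetic (a sketch using this file's config defaults):
# local_attn_size = kv_cache_num_frames + num_frames_per_block = 3 + 3 = 6 frames,
# so each per-block KV cache holds 6 * frame_seq_length = 6 * 1560 = 9360 tokens --
# the cached context window plus the block currently being denoised.
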
class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
    model_name = "WanRT"

    @property
    def description(self) -> str:
        return (
            "Recomputes the KV cache from previously generated context frames before "
            "denoising the current block. Skipped for the first block."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Current block latents [B, C, num_frames_per_block, H, W]",
            ),
            InputParam(
                "num_frames_per_block",
                type_hint=int,
                description="Number of frames per block",
            ),
            InputParam(
                "block_idx",
                type_hint=int,
                description="Current block index to process",
            ),
            InputParam(
                "block_mask",
                description="Block-wise causal attention mask",
            ),
            InputParam(
                "current_start_frame",
                type_hint=int,
                description="Starting frame index for current block",
            ),
            InputParam(
                "videos",
                type_hint=torch.Tensor,
                description="Video frames for context encoding",
            ),
            InputParam(
                "final_latents",
                type_hint=torch.Tensor,
                description="Full latent buffer [B, C, total_frames, H, W]",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Text embeddings to guide generation",
            ),
            InputParam(
                "kv_cache",
                type_hint=List[Dict],
                description="Key-value cache for attention",
            ),
            InputParam(
                "crossattn_cache",
                type_hint=List[Dict],
                description="Cross-attention cache",
            ),
            InputParam(
                "encoder_cache",
                description="Encoder feature cache",
            ),
            InputParam(
                "frame_cache_context",
                description="Cached context frames for re-encoding",
            ),
            InputParam(
                "local_attn_size",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec("seq_length", 32760),
            ConfigSpec("kv_cache_num_frames", 3),
            ConfigSpec("frame_seq_length", 1560),
        ]

    def prepare_latents(self, components, block_state):
        # Re-encode the cached first frame so the context window always keeps an
        # anchor at the start of the video.
        frames = block_state.frame_cache_context[0].half()

        # Reset the VAE encoder's cached features before re-encoding.
        components.vae._enc_feat_map = [None] * 55
        latents = retrieve_latents(components.vae.encode(frames), sample_mode="argmax")
        latents_mean = (
            torch.tensor(components.vae.config.latents_mean)
            .view(1, components.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
            1, components.vae.config.z_dim, 1, 1, 1
        ).to(latents.device, latents.dtype)
        latents = (latents - latents_mean) * latents_std

        return latents.to(components.transformer.dtype)

    def get_context_frames(self, components, block_state):
        current_kv_cache_num_frames = components.config.kv_cache_num_frames
        context_frames = block_state.final_latents[
            :, :, : block_state.current_start_frame
        ]

        # While the generated video is still shorter than the cache window, keep the
        # first frame plus the most recent frames; afterwards, re-encode the first
        # frame from pixels and prepend it to the rolling window.
        if (
            block_state.block_idx - 1
        ) * block_state.num_frames_per_block < current_kv_cache_num_frames:
            if current_kv_cache_num_frames == 1:
                context_frames = context_frames[:, :, :1]
            else:
                context_frames = torch.cat(
                    (
                        context_frames[:, :, :1],
                        context_frames[:, :, 1:][
                            :, :, -current_kv_cache_num_frames + 1 :
                        ],
                    ),
                    dim=2,
                )
        else:
            context_frames = context_frames[:, :, 1:][
                :, :, -current_kv_cache_num_frames + 1 :
            ]
            first_frame_latent = self.prepare_latents(components, block_state)
            first_frame_latent = first_frame_latent.to(block_state.latents)
            context_frames = torch.cat((first_frame_latent, context_frames), dim=2)

        return context_frames

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        # The first block has no previous context to cache.
        if block_state.block_idx == 0:
            return components, state

        # Clamp the cache write offset to the rolling window size.
        start_frame = min(
            block_state.current_start_frame, components.config.kv_cache_num_frames
        )
        context_frames = self.get_context_frames(components, block_state)
        block_state.block_mask = (
            components.transformer._prepare_blockwise_causal_attn_mask(
                components.transformer.device,
                num_frames=context_frames.shape[2],
                frame_seqlen=components.config.frame_seq_length,
                num_frame_per_block=block_state.num_frames_per_block,
                local_attn_size=-1,
            )
        )
        components.transformer.block_mask = block_state.block_mask
        # Run the transformer at timestep 0 over the context frames purely to
        # repopulate the KV cache; the denoising output is discarded.
        context_timestep = torch.zeros(
            (context_frames.shape[0], context_frames.shape[2]),
            device=components.transformer.device,
            dtype=torch.int64,
        )
        components.transformer(
            x=context_frames.to(components.transformer.dtype),
            t=context_timestep,
            context=block_state.prompt_embeds.to(components.transformer.dtype),
            kv_cache=block_state.kv_cache,
            seq_len=components.config.seq_length,
            crossattn_cache=block_state.crossattn_cache,
            current_start=start_frame * components.config.frame_seq_length,
            cache_start=None,
        )
        components.transformer.block_mask = None

        self.set_block_state(state, block_state)
        return components, state

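# Context-window walkthrough (a sketch, with kv_cache_num_frames=3 and
# num_frames_per_block=3): for block_idx=1 the condition (1-1)*3 < 3 holds, so the
# context is the first frame plus the two most recent frames already generated;
# from block_idx=2 onward the `else` branch applies, keeping the two most recent
# frames and prepending a first frame re-encoded from cached pixels.
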
class WanRTStreamingBeforeDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        WanRTStreamingSetTimestepsStep,
        WanRTStreamingPrepareLatentsStep,
        WanRTStreamingExtractBlockLatentsStep,
        WanRTStreamingSetupKVCache,
        WanRTStreamingRecomputeKVCache,
    ]
    block_names = [
        "set_timesteps",
        "prepare_latents",
        "extract_block_init_latents",
        "setup_kv_cache",
        "recompute_kv_cache",
    ]

    @property
    def description(self):
        return (
            "Before-denoise step that prepares the inputs for the denoise step.\n"
            + "This is a sequential pipeline of blocks:\n"
            + " - `WanRTStreamingSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanRTStreamingPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanRTStreamingExtractBlockLatentsStep` is used to slice out the current block's latents\n"
            + " - `WanRTStreamingSetupKVCache` is used to initialize the KV and cross-attention caches\n"
            + " - `WanRTStreamingRecomputeKVCache` is used to refresh the KV cache from previous context\n"
        )

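# Illustrative composition (a sketch; assumes the standard modular-pipelines flow
# in which a blocks instance is converted into a runnable pipeline -- the
# `init_pipeline` call and checkpoint id below are placeholders, and the denoise /
# decode blocks that would complete the pipeline are defined elsewhere):
#
#   before_denoise = WanRTStreamingBeforeDenoiseStep()
#   pipe = before_denoise.init_pipeline("<checkpoint-or-repo-id>")
#   components_and_state = pipe(prompt_embeds=embeds, block_idx=0, num_blocks=9)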