blur2vid / training /utils.py
ftaubner's picture
initial commit
7245cc5
raw
history blame
10 kB
import os
from typing import List, Optional, Union, Tuple
import torch
from transformers import T5EncoderModel, T5Tokenizer
import numpy as np
import cv2
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from accelerate.logging import get_logger
import tempfile
import argparse
import yaml
import shutil
# Module-level logger obtained from accelerate's get_logger; used for the
# configuration warnings emitted by get_optimizer.
logger = get_logger(__name__)
def get_args():
    """Parse ``--config`` from the command line and load that YAML file.

    Returns:
        argparse.Namespace whose attributes are the top-level keys of the YAML
        mapping. Nested mappings are kept as plain dicts (no recursive
        conversion).

    Raises:
        ValueError: if the YAML document is empty or its top level is not a
            mapping (e.g. a list or a scalar).
    """
    parser = argparse.ArgumentParser(description="Training script for CogVideoX using config file.")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the YAML config file.",
    )
    cli_args = parser.parse_args()

    with open(cli_args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty file and may return a list/scalar;
    # Namespace(**config) needs a mapping, so fail with a clear message instead
    # of an opaque "argument after ** must be a mapping" TypeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {cli_args.config!r} must contain a YAML mapping at the top level, "
            f"got {type(config).__name__} (is the file empty?)"
        )

    # Expose the top-level config keys as attributes for downstream code.
    return argparse.Namespace(**config)
def atomic_save(save_path, accelerator):
    """Save ``accelerator`` state to ``save_path`` without leaving a half-written checkpoint.

    The state is first written to a fresh temporary directory next to
    ``save_path``, then renamed into place. An existing checkpoint is moved
    aside to ``<save_path>_backup`` for the duration of the swap and is put
    back if the swap fails. ``save_path`` therefore only ever appears via a
    rename of a fully written directory.

    Leftovers from an earlier interrupted call are reconciled first:

    * If only the backup exists, it is restored as ``save_path``.
    * If both exist, the backup is stale (the newer checkpoint was already
      renamed into place) and is removed.

    Args:
        save_path: Destination directory for the checkpoint. A trailing path
            separator is accepted.
        accelerator: Object exposing ``save_state(output_dir)``, presumably an
            accelerate ``Accelerator`` (TODO confirm at call sites).

    Raises:
        Whatever ``accelerator.save_state`` or the filesystem operations raise.
        When the save itself fails, the previous checkpoint (if any) is left
        in place and the temporary directory is removed.
    """
    # Normalise so a trailing separator cannot put the temp/backup directories
    # *inside* save_path (dirname("a/ckpt/") == "a/ckpt").
    save_path = os.path.normpath(save_path)
    parent = os.path.dirname(save_path) or os.curdir
    backup_dir = save_path + "_backup"
    os.makedirs(parent, exist_ok=True)

    # Reconcile a backup left behind by a previously interrupted call before
    # we try to reuse its name; renaming onto a non-empty dir would fail.
    if os.path.exists(backup_dir):
        if os.path.exists(save_path):
            shutil.rmtree(backup_dir)
        else:
            os.rename(backup_dir, save_path)

    # Temp dir lives in the same parent so the final rename stays on one filesystem.
    tmp_dir = tempfile.mkdtemp(dir=parent)
    moved_old = False
    try:
        accelerator.save_state(tmp_dir)
        if os.path.exists(save_path):
            os.rename(save_path, backup_dir)
            moved_old = True
        os.rename(tmp_dir, save_path)
    except BaseException:
        # Drop the partial save; ignore_errors so the original error propagates.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        # Put the previous checkpoint back only if the new one never landed.
        if moved_old and not os.path.exists(save_path):
            os.rename(backup_dir, save_path)
        raise

    # The new checkpoint is in place. Removing the backup happens outside the
    # try so a cleanup failure can never roll back a successful save.
    if moved_old:
        shutil.rmtree(backup_dir)
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    """Build the optimizer described by ``args``.

    Args:
        args: Namespace providing the following attributes:

            * ``optimizer``: "adam", "adamw" or "prodigy", case-insensitive.
            * ``learning_rate``, ``adam_beta1``, ``adam_beta2``,
              ``adam_epsilon``, ``adam_weight_decay``, ``use_8bit_adam``.
            * For Prodigy only: ``prodigy_beta3``, ``prodigy_decouple``,
              ``prodigy_use_bias_correction``, ``prodigy_safeguard_warmup``.

            An unsupported ``optimizer`` falls back to AdamW with a warning,
            and ``args.optimizer`` is overwritten with "adamw".
        params_to_optimize: Parameters or parameter-group dicts accepted by
            ``torch.optim`` optimizers. A group's own ``"lr"`` takes
            precedence over ``args.learning_rate``.
        use_deepspeed: If True, return an ``accelerate.utils.DummyOptim``
            built from the Adam hyper-parameters (the DeepSpeed-optimizer path).

    Returns:
        The constructed optimizer instance.

    Raises:
        ImportError: if bitsandbytes (8-bit Adam) or prodigyopt (Prodigy) is
            required but not installed.
    """
    # Use DeepSpeed optimizer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    supported_optimizers = ["adam", "adamw", "prodigy"]
    # Compare case-insensitively so e.g. "AdamW" is accepted, matching the
    # lower-cased dispatch below.
    optimizer_name = (args.optimizer or "").lower()
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
        )
        args.optimizer = "adamw"
        optimizer_name = "adamw"

    is_adam_family = optimizer_name in ("adam", "adamw")
    # Warn only when 8-bit Adam was requested but cannot apply.
    if args.use_8bit_adam and not is_adam_family:
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
            f"set to {optimizer_name}"
        )
    use_8bit_adam = args.use_8bit_adam and is_adam_family

    if is_adam_family:
        if use_8bit_adam:
            # Import lazily so bitsandbytes is only required when actually used.
            try:
                import bitsandbytes as bnb
            except ImportError as err:
                raise ImportError(
                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
                ) from err
            optimizer_class = bnb.optim.AdamW8bit if optimizer_name == "adamw" else bnb.optim.Adam8bit
        else:
            optimizer_class = torch.optim.AdamW if optimizer_name == "adamw" else torch.optim.Adam
        # lr is the default for plain parameter lists; param groups that carry
        # their own "lr" keep it.
        return optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    # Only remaining supported choice: Prodigy.
    try:
        import prodigyopt
    except ImportError as err:
        raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") from err

    if args.learning_rate <= 0.1:
        logger.warning(
            "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
        )
    return prodigyopt.Prodigy(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        beta3=args.prodigy_beta3,
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        decouple=args.prodigy_decouple,
        use_bias_correction=args.prodigy_use_bias_correction,
        safeguard_warmup=args.prodigy_safeguard_warmup,
    )
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute 3D rotary position embeddings ``(cos, sin)`` for a token grid.

    Pixel sizes are turned into token-grid sizes by integer division with
    ``vae_scale_factor_spatial * patch_size``. The crop region comes from
    ``get_resize_crop_region_for_grid`` relative to the grid of the base
    resolution ``base_height x base_width``.

    Returns:
        ``(freqs_cos, freqs_sin)`` as returned by ``get_3d_rotary_pos_embed``,
        each moved to ``device``.
    """
    # Pixels spanned by one token along each spatial axis.
    pixels_per_token = vae_scale_factor_spatial * patch_size

    rows, cols = height // pixels_per_token, width // pixels_per_token
    base_cols = base_width // pixels_per_token
    base_rows = base_height // pixels_per_token

    crop_region = get_resize_crop_region_for_grid((rows, cols), base_cols, base_rows)
    cos_part, sin_part = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=crop_region,
        grid_size=(rows, cols),
        temporal_size=num_frames,
    )
    return cos_part.to(device=device), sin_part.to(device=device)
def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    """Encode prompts with the T5 text encoder and repeat them per video.

    Args:
        tokenizer: T5 tokenizer. If None, ``text_input_ids`` must be given and
            ``prompt`` is not used.
        text_encoder: Encoder called on the token ids; element ``[0]`` of its
            output is taken as the embeddings.
        prompt: A single prompt or a list of prompts (tokenized only when
            ``tokenizer`` is given).
        num_videos_per_prompt: Number of copies of each prompt's embedding.
        max_sequence_length: Tokens are padded/truncated to this length.
        device: Device for the ids and the resulting embeddings.
        dtype: dtype for the resulting embeddings.
        text_input_ids: Pre-tokenized ids, used when ``tokenizer`` is None.

    Returns:
        Tensor of shape ``(batch * num_videos_per_prompt, seq_len, hidden)``.
        The copies of one prompt are consecutive along dim 0.

    Raises:
        ValueError: if ``tokenizer`` is None and ``text_input_ids`` is None.
    """
    if tokenizer is not None:
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")

    # Take the batch size from the ids actually encoded, so the pre-tokenized
    # path works even when ``prompt`` is None or does not match the ids.
    batch_size = text_input_ids.shape[0]

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
    return prompt_embeds
def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    """Return T5 prompt embeddings for ``prompt``.

    This is a thin wrapper. It wraps a bare string prompt in a list and
    forwards every argument unchanged to ``_get_t5_prompt_embeds``.
    """
    prompt_list = prompt if not isinstance(prompt, str) else [prompt]
    return _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt_list,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    """Encode ``prompt`` into one embedding per prompt.

    With ``requires_grad=False`` the encoder runs under ``torch.no_grad()``.
    Otherwise it runs in whatever autograd mode the caller is already in.
    """

    def _run_encoder():
        # Always one embedding per prompt here.
        return encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    if not requires_grad:
        with torch.no_grad():
            return _run_encoder()
    return _run_encoder()
def save_frames_as_pngs(
    video_array,
    output_dir,
    downsample_spatial=1,  # e.g. 2 to halve width & height
    downsample_temporal=1,  # e.g. 2 to keep every 2nd frame
):
    """Save each frame of a ``(T, H, W, 3)`` uint8 RGB array as an uncompressed PNG.

    Args:
        video_array: ``np.uint8`` array of shape ``(T, H, W, 3)`` in RGB order.
        output_dir: Destination directory; created if it does not exist.
        downsample_spatial: Integer factor dividing width and height
            (nearest-neighbour resize). 1 keeps full resolution.
        downsample_temporal: Keep every ``downsample_temporal``-th frame.

    Files are named ``frame_00000.png``, ``frame_00001.png``, ..., numbered by
    position in the temporally downsampled sequence.

    Raises:
        ValueError: if the array has the wrong shape or dtype, or a downsample
            factor is < 1 or too large for the frame size.
        RuntimeError: if OpenCV fails to write a frame.
    """
    # Explicit checks instead of assert, which is stripped under ``python -O``.
    if video_array.ndim != 4 or video_array.shape[-1] != 3:
        raise ValueError("Expected (T, H, W, C=3) array")
    if video_array.dtype != np.uint8:
        raise ValueError("Expected uint8 array")
    if downsample_spatial < 1 or downsample_temporal < 1:
        raise ValueError(
            f"Downsample factors must be >= 1, got spatial={downsample_spatial}, "
            f"temporal={downsample_temporal}"
        )

    # temporal downsample
    frames = video_array[::downsample_temporal]

    # cv2.resize takes (width, height)
    _, height, width, _ = frames.shape
    new_size = (width // downsample_spatial, height // downsample_spatial)
    if downsample_spatial > 1 and min(new_size) == 0:
        raise ValueError(
            f"downsample_spatial={downsample_spatial} is too large for {width}x{height} frames"
        )

    os.makedirs(output_dir, exist_ok=True)

    # PNG compression param: 0 = no compression
    png_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]
    for idx, frame in enumerate(frames):
        # Input is RGB; OpenCV writes BGR. cvtColor yields a contiguous array
        # rather than a negative-stride view.
        bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        if downsample_spatial > 1:
            bgr = cv2.resize(bgr, new_size, interpolation=cv2.INTER_NEAREST)
        filename = os.path.join(output_dir, "frame_{:05d}.png".format(idx))
        if not cv2.imwrite(filename, bgr, png_params):
            raise RuntimeError(f"Failed to write frame {idx} to {filename}")