# Author: YiYiXu
# Commit: "Upload folder using huggingface_hub" (46a7fb0, verified)
# File size: 1.51 kB
from typing import Any, List, Tuple

import torch
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.modular_pipeline import PipelineState, PipelineBlock
from diffusers.pipelines.modular_pipeline_utils import (
    InputParam,
    ComponentSpec,
    OutputParam,
)
from image_gen_aux import DepthPreprocessor
class DepthProcessorBlock(PipelineBlock):
    """Pipeline block that converts input image(s) into depth map(s).

    The block reads ``image`` from the pipeline state and runs it through the
    ``depth_processor`` component (declared with a ``DepthPreprocessor`` type
    hint). It then writes the result back to the state under the same
    ``image`` key, replacing the original input.
    """

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block needs the pipeline to provide.

        Returns:
            A single ``ComponentSpec`` named ``depth_processor``, type-hinted
            as ``DepthPreprocessor`` and pointing at the
            ``depth-anything/Depth-Anything-V2-Large-hf`` repo.
        """
        return [
            ComponentSpec(
                name="depth_processor",
                type_hint=DepthPreprocessor,
                repo="depth-anything/Depth-Anything-V2-Large-hf",
            )
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Inputs read from the pipeline state: the required ``image``."""
        return [
            InputParam(
                "image",
                PipelineImageInput,
                required=True,
                description="Image(s) to use to extract depth maps",
            )
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        """Values written back to the state: ``image``, now a depth-map tensor."""
        return [
            OutputParam(
                "image",
                type_hint=torch.Tensor,
                description="Depth Map(s) of input Image(s)",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, pipeline, state: PipelineState
    ) -> Tuple[Any, PipelineState]:
        """Compute depth map(s) for the state's ``image`` and store them back.

        Args:
            pipeline: The running pipeline. It must expose a
                ``depth_processor`` component.
            state: Shared pipeline state holding the input ``image``.

        Returns:
            The ``(pipeline, state)`` pair. In ``state``, ``image`` has been
            replaced by the depth map(s).
        """
        block_state = self.get_block_state(state)

        depth_map = pipeline.depth_processor(block_state.image)
        # NOTE(review): `.to(...)` assumes the preprocessor returns a torch
        # tensor. `block_state.device` assumes the state already carries a
        # `device` value, which this block does not declare as an input.
        # Confirm both against the pipeline this block runs in.
        block_state.image = depth_map.to(block_state.device)

        self.add_block_state(state, block_state)
        return pipeline, state