## WanVideoAnimateEmbeds
```python=
class WanVideoAnimateEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "vae": ("WANVAE",),
            "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
            "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
            "force_offload": ("BOOLEAN", {"default": True}),
            "frame_window_size": ("INT", {"default": 77, "min": 1, "max": 1000, "step": 1, "tooltip": "Number of frames to use for temporal attention window"}),
            "colormatch": (
                [
                    'disabled',
                    'mkl',
                    'hm',
                    'reinhard',
                    'mvgd',
                    'hm-mvgd-hm',
                    'hm-mkl-hm',
                ], {
                    "default": 'disabled', "tooltip": "Color matching method to use between the windows"
                },),
            "pose_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the pose"}),
            "face_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the face"}),
            },
            "optional": {
                "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}),
                "ref_images": ("IMAGE", {"tooltip": "Reference image(s) to encode"}),
                "pose_images": ("IMAGE", {"tooltip": "Pose driving images"}),
                "face_images": ("IMAGE", {"tooltip": "Face driving images"}),
                "bg_images": ("IMAGE", {"tooltip": "Background images"}),
                "mask": ("MASK", {"tooltip": "mask"}),
                "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "pf_custom_nodes/WanVideoWrapper"

    def process(self, vae, width, height, num_frames, force_offload, frame_window_size, colormatch, pose_strength, face_strength,
                ref_images=None, pose_images=None, face_images=None, clip_embeds=None, tiled_vae=False, bg_images=None, mask=None):
        H = height
        W = width
        lat_h = H // vae.upsampling_factor
        lat_w = W // vae.upsampling_factor

        num_refs = ref_images.shape[0] if ref_images is not None else 0
        num_frames = ((num_frames - 1) // 4) * 4 + 1
        looping = num_frames > frame_window_size
        if num_frames < frame_window_size:
            frame_window_size = num_frames

        target_shape = (16, (num_frames - 1) // 4 + 1 + num_refs, lat_h, lat_w)
        latent_window_size = (frame_window_size - 1) // 4
        if not looping:
            num_frames = num_frames + num_refs * 4
        else:
            latent_window_size = latent_window_size + 1

        mm.soft_empty_cache()
        gc.collect()
        vae.to(device)

        # Resize and rearrange the input image dimensions
        pose_latents = ref_latents = ref_latent = None
        if pose_images is not None:
            pose_images = pose_images[..., :3]
            if pose_images.shape[1] != H or pose_images.shape[2] != W:
                resized_pose_images = common_upscale(pose_images.movedim(-1, 1), W, H, "lanczos", "center").movedim(0, 1)
            else:
                resized_pose_images = pose_images.permute(3, 0, 1, 2)  # C, T, H, W
            resized_pose_images = resized_pose_images * 2 - 1
            if not looping:
                pose_latents = vae.encode([resized_pose_images.to(device, vae.dtype)], device, tiled=tiled_vae)
                pose_latents = pose_latents.to(offload_device)
                if pose_latents.shape[2] < latent_window_size:
                    print(f"WanAnimate: Padding pose latents from {pose_latents.shape} to length {latent_window_size}")
                    pad_len = latent_window_size - pose_latents.shape[2]
                    pad = torch.zeros(pose_latents.shape[0], pose_latents.shape[1], pad_len, pose_latents.shape[3], pose_latents.shape[4], device=pose_latents.device, dtype=pose_latents.dtype)
                    pose_latents = torch.cat([pose_latents, pad], dim=2)
                del resized_pose_images
            else:
                resized_pose_images = resized_pose_images.to(offload_device, dtype=vae.dtype)

        bg_latents = None
        if bg_images is not None:
            if bg_images.shape[1] != H or bg_images.shape[2] != W:
                resized_bg_images = common_upscale(bg_images.movedim(-1, 1), W, H, "lanczos", "center").movedim(0, 1)
            else:
                resized_bg_images = bg_images.permute(3, 0, 1, 2)  # C, T, H, W
            resized_bg_images = resized_bg_images[:3] * 2 - 1
        if not looping:
            if bg_images is None:
                resized_bg_images = torch.zeros(3, num_frames - num_refs, H, W, device=device, dtype=vae.dtype)
            bg_latents = vae.encode([resized_bg_images.to(device, vae.dtype)], device, tiled=tiled_vae)[0].to(offload_device)
            del resized_bg_images
        elif bg_images is not None:
            resized_bg_images = resized_bg_images.to(offload_device, dtype=vae.dtype)

        if ref_images is not None:
            if ref_images.shape[1] != H or ref_images.shape[2] != W:
                resized_ref_images = common_upscale(ref_images.movedim(-1, 1), W, H, "lanczos", "center").movedim(0, 1)
            else:
                resized_ref_images = ref_images.permute(3, 0, 1, 2)  # C, T, H, W
            resized_ref_images = resized_ref_images[:3] * 2 - 1
            ref_latent = vae.encode([resized_ref_images.to(device, vae.dtype)], device, tiled=tiled_vae)[0]

            msk = torch.zeros(4, 1, lat_h, lat_w, device=device, dtype=vae.dtype)
            msk[:, :num_refs] = 1
            import logging
            logging.info(f"WanVideoAnimateEmbeds: ref_latent shape {ref_latent.shape}")
            logging.info(f"WanVideoAnimateEmbeds: msk shape {msk.shape}")
            ref_latent_masked = torch.cat([msk, ref_latent], dim=0).to(offload_device)  # 4+C 1 H W

            if mask is None:
                bg_mask = torch.zeros(1, num_frames, lat_h, lat_w, device=offload_device, dtype=vae.dtype)
            else:
                bg_mask = 1 - mask[:num_frames]
                if bg_mask.shape[0] < num_frames and not looping:
                    bg_mask = torch.cat([bg_mask, bg_mask[-1:].repeat(num_frames - bg_mask.shape[0], 1, 1)], dim=0)
                bg_mask = common_upscale(bg_mask.unsqueeze(1), lat_w, lat_h, "nearest", "center").squeeze(1)
                bg_mask = bg_mask.unsqueeze(-1).permute(3, 0, 1, 2).to(offload_device, vae.dtype)  # C, T, H, W
            if bg_images is None and looping:
                bg_mask[:, :num_refs] = 1

            bg_mask_mask_repeated = torch.repeat_interleave(bg_mask[:, 0:1], repeats=4, dim=1)  # first mask frame repeated 4x
            bg_mask = torch.cat([bg_mask_mask_repeated, bg_mask[:, 1:]], dim=1)
            bg_mask = bg_mask.view(1, bg_mask.shape[1] // 4, 4, lat_h, lat_w)  # 1, T, C, H, W
            bg_mask = bg_mask.movedim(1, 2)[0]  # C, T, H, W

            if not looping:
                bg_latents_masked = torch.cat([bg_mask[:, :bg_latents.shape[1]], bg_latents], dim=0)
                del bg_mask, bg_latents
                ref_latent = torch.cat([ref_latent_masked, bg_latents_masked], dim=1)
            else:
                ref_latent = ref_latent_masked

        if face_images is not None:
            face_images = face_images[..., :3]
            if face_images.shape[1] != 512 or face_images.shape[2] != 512:
                resized_face_images = common_upscale(face_images.movedim(-1, 1), 512, 512, "lanczos", "center").movedim(0, 1)
            else:
                resized_face_images = face_images.permute(3, 0, 1, 2)  # C, T, H, W
            resized_face_images = (resized_face_images * 2 - 1).unsqueeze(0)
            resized_face_images = resized_face_images.to(offload_device, dtype=vae.dtype)

        seq_len = math.ceil((target_shape[2] * target_shape[3]) / 4 * target_shape[1])

        if force_offload:
            vae.model.to(offload_device)
            mm.soft_empty_cache()
            gc.collect()

        image_embeds = {
            "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None,
            "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None,
            "max_seq_len": seq_len,
            "pose_latents": pose_latents,
            "pose_images": resized_pose_images if pose_images is not None and looping else None,
            "bg_images": resized_bg_images if bg_images is not None and looping else None,
            "ref_masks": bg_mask if mask is not None and looping else None,
            "is_masked": mask is not None,
            "ref_latent": ref_latent,
            "ref_image": resized_ref_images if ref_images is not None else None,
            "face_pixels": resized_face_images if face_images is not None else None,
            "num_frames": num_frames,
            "target_shape": target_shape,
            "frame_window_size": frame_window_size,
            "lat_h": lat_h,
            "lat_w": lat_w,
            "vae": vae,
            "colormatch": colormatch,
            "looping": looping,
            "pose_strength": pose_strength,
            "face_strength": face_strength,
        }
        return (image_embeds,)
```
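The sizes this node reports follow from a small amount of rounding arithmetic: frame counts are snapped down to the 4k+1 grid the temporal VAE expects, and `max_seq_len` is the transformer token count for the resulting latent video. A minimal standalone sketch of that arithmetic, assuming the usual 8x spatial VAE factor (`vae.upsampling_factor == 8`) and one reference image:

```python=
# Standalone sketch of the frame/latent bookkeeping in WanVideoAnimateEmbeds.
# Assumes vae.upsampling_factor == 8 and a 2x2 spatial patchify (hence the /4).
import math

def animate_embed_dims(width=832, height=480, num_frames=81, num_refs=1,
                       vae_upsampling_factor=8):
    lat_h, lat_w = height // vae_upsampling_factor, width // vae_upsampling_factor
    # frames are snapped down to the 4k+1 grid the temporal VAE expects
    num_frames = ((num_frames - 1) // 4) * 4 + 1
    latent_frames = (num_frames - 1) // 4 + 1 + num_refs
    target_shape = (16, latent_frames, lat_h, lat_w)
    max_seq_len = math.ceil((lat_h * lat_w) / 4 * latent_frames)
    return target_shape, max_seq_len

print(animate_embed_dims())  # ((16, 22, 60, 104), 34320) for the defaults
```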
## WanVideoSampler
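The excerpt below is the `wananimate_loop` branch of the sampler. It first pads every per-frame input (face crops, pose frames, background frames, reference masks) to a common target length with `tensor_pingpong_pad` from the wrapper's `utils`. That helper is not reproduced here; a plausible stand-in, offered purely as an assumption for illustration, reflects the sequence back and forth along its frame dimension:

```python=
# Illustrative stand-in for tensor_pingpong_pad (the real helper lives in the
# wrapper's utils and may differ, e.g. in how it picks the frame dimension):
# extend a tensor along `dim` to `target_len` by ping-ponging over its frames.
import torch

def tensor_pingpong_pad(x: torch.Tensor, target_len: int, dim: int = 1) -> torch.Tensor:
    n = x.shape[dim]
    if n >= target_len:
        return x.narrow(dim, 0, target_len)
    idx, step = list(range(n)), -1
    while len(idx) < target_len:
        nxt = idx[-1] + step
        if not (0 <= nxt < n):        # bounce off either end
            step = -step
            nxt = idx[-1] + step
            if not (0 <= nxt < n):    # single-frame input: stay on that frame
                nxt = idx[-1]
        idx.append(nxt)
    return x.index_select(dim, torch.tensor(idx, device=x.device))
```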
```python=
# region wananimate loop
elif wananimate_loop:
    # calculate frame counts
    total_frames = num_frames
    refert_num = 1
    real_clip_len = frame_window_size - refert_num
    last_clip_num = (total_frames - refert_num) % real_clip_len
    extra = 0 if last_clip_num == 0 else real_clip_len - last_clip_num
    target_len = total_frames + extra
    estimated_iterations = target_len // real_clip_len
    target_latent_len = (target_len - 1) // 4 + estimated_iterations
    latent_window_size = (frame_window_size - 1) // 4 + 1

    from .utils import tensor_pingpong_pad
    ref_latent = image_embeds.get("ref_latent", None)
    ref_images = image_embeds.get("ref_image", None)
    bg_images = image_embeds.get("bg_images", None)
    pose_images = image_embeds.get("pose_images", None)

    current_ref_images = face_images = face_images_in = None
    if wananim_face_pixels is not None:
        face_images = tensor_pingpong_pad(wananim_face_pixels, target_len)
        log.info(f"WanAnimate: Face input {wananim_face_pixels.shape} padded to shape {face_images.shape}")
    if wananim_ref_masks is not None:
        ref_masks_in = tensor_pingpong_pad(wananim_ref_masks, target_latent_len)
        log.info(f"WanAnimate: Ref masks {wananim_ref_masks.shape} padded to shape {ref_masks_in.shape}")
    if bg_images is not None:
        bg_images_in = tensor_pingpong_pad(bg_images, target_len)
        log.info(f"WanAnimate: BG images {bg_images.shape} padded to shape {bg_images_in.shape}")
    if pose_images is not None:
        pose_images_in = tensor_pingpong_pad(pose_images, target_len)
        log.info(f"WanAnimate: Pose images {pose_images.shape} padded to shape {pose_images_in.shape}")

    # init variables
    offloaded = False
    colormatch = image_embeds.get("colormatch", "disabled")
    output_path = image_embeds.get("output_path", "")
    offload = image_embeds.get("force_offload", False)
    lat_h, lat_w = noise.shape[2], noise.shape[3]
    start = start_latent = img_counter = step_iteration_count = iteration_count = 0
    end = frame_window_size
    end_latent = latent_window_size
    masks = None  # differential diffusion masks, set below when a noise_mask is provided
    callback = prepare_callback(patcher, estimated_iterations)
    log.info(f"Sampling {total_frames} frames in {estimated_iterations} windows, at {latent.shape[3]*vae_upscale_factor}x{latent.shape[2]*vae_upscale_factor} with {steps} steps")

    # outer WanAnimate loop
    gen_video_list = []
    while True:
        if start + refert_num >= total_frames:
            break
        mm.soft_empty_cache()

        mask_reft_len = 0 if start == 0 else refert_num
        self.cache_state = [None, None]
        noise = torch.randn(16, latent_window_size + 1, lat_h, lat_w, dtype=torch.float32, device=torch.device("cpu"), generator=seed_g).to(device)
        seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1])

        if current_ref_images is not None or bg_images is not None or ref_latent is not None:
            if offload:
                offload_transformer(transformer, remove_lora=False)
                offloaded = True
            vae.to(device)
            if wananim_ref_masks is not None:
                msk = ref_masks_in[:, start_latent:end_latent].to(device, dtype)
            else:
                msk = torch.zeros(4, latent_window_size, lat_h, lat_w, device=device, dtype=dtype)
            if bg_images is not None:
                bg_image_slice = bg_images_in[:, start:end].to(device)
            else:
                bg_image_slice = torch.zeros(3, frame_window_size - refert_num, lat_h * 8, lat_w * 8, device=device, dtype=vae.dtype)
            if mask_reft_len == 0:
                temporal_ref_latents = vae.encode([bg_image_slice], device, tiled=tiled_vae)[0]
            else:
                concatenated = torch.cat([current_ref_images.to(device, dtype=vae.dtype), bg_image_slice[:, mask_reft_len:]], dim=1)
                temporal_ref_latents = vae.encode([concatenated.to(device, vae.dtype)], device, tiled=tiled_vae, pbar=False)[0]
                msk[:, :mask_reft_len] = 1
            if msk.shape[1] != temporal_ref_latents.shape[1]:
                if temporal_ref_latents.shape[1] < msk.shape[1]:
                    pad_len = msk.shape[1] - temporal_ref_latents.shape[1]
                    pad_tensor = temporal_ref_latents[:, -1:].repeat(1, pad_len, 1, 1)
                    temporal_ref_latents = torch.cat([temporal_ref_latents, pad_tensor], dim=1)
                else:
                    temporal_ref_latents = temporal_ref_latents[:, :msk.shape[1]]
            if ref_latent is not None:
                temporal_ref_latents = torch.cat([msk, temporal_ref_latents], dim=0)  # 4+C T H W
                image_cond_in = torch.cat([ref_latent.to(device), temporal_ref_latents], dim=1)  # 4+C T+trefs H W
                del temporal_ref_latents, msk, bg_image_slice
            else:
                image_cond_in = torch.cat([torch.tile(torch.zeros_like(noise[:1]), [4, 1, 1, 1]), torch.zeros_like(noise)], dim=0).to(device)
        else:
            image_cond_in = torch.cat([torch.tile(torch.zeros_like(noise[:1]), [4, 1, 1, 1]), torch.zeros_like(noise)], dim=0).to(device)

        pose_input_slice = None
        if pose_images is not None:
            vae.to(device)
            pose_image_slice = pose_images_in[:, start:end].to(device)
            pose_input_slice = vae.encode([pose_image_slice], device, tiled=tiled_vae, pbar=False).to(dtype)
            vae.to(offload_device)

        if wananim_face_pixels is None and wananim_ref_masks is not None:
            face_images_in = torch.zeros(1, 3, frame_window_size, 512, 512, device=device, dtype=torch.float32)
        elif wananim_face_pixels is not None:
            face_images_in = face_images[:, :, start:end].to(device, torch.float32) if face_images is not None else None

        if samples is not None:
            input_samples = samples["samples"]
            noise_mask = samples.get("noise_mask", None)
            if input_samples is not None:
                input_samples = input_samples.squeeze(0).to(noise)
                # Check if we have enough frames in input_samples
                if latent_end_idx > input_samples.shape[1]:
                    # We need more frames than available - pad the input_samples at the end
                    pad_length = latent_end_idx - input_samples.shape[1]
                    last_frame = input_samples[:, -1:].repeat(1, pad_length, 1, 1)
                    input_samples = torch.cat([input_samples, last_frame], dim=1)
                input_samples = input_samples[:, latent_start_idx:latent_end_idx]
                if noise_mask is not None:
                    original_image = input_samples.to(device)
                assert input_samples.shape[1] == noise.shape[1], f"Slice mismatch: {input_samples.shape[1]} vs {noise.shape[1]}"
                if add_noise_to_samples:
                    latent_timestep = timesteps[0]
                    noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples
                else:
                    noise = input_samples
            # diff diff prep
            if noise_mask is not None:
                if len(noise_mask.shape) == 4:
                    noise_mask = noise_mask.squeeze(1)
                if noise_mask.shape[0] < noise.shape[1]:
                    noise_mask = noise_mask.repeat(noise.shape[1] // noise_mask.shape[0], 1, 1)
                else:
                    noise_mask = noise_mask[start_latent:end_latent]
                noise_mask = torch.nn.functional.interpolate(
                    noise_mask.unsqueeze(0).unsqueeze(0),  # Add batch and channel dims [1,1,T,H,W]
                    size=(noise.shape[1], noise.shape[2], noise.shape[3]),
                    mode='trilinear',
                    align_corners=False
                ).repeat(1, noise.shape[0], 1, 1, 1)
                thresholds = torch.arange(len(timesteps), dtype=original_image.dtype) / len(timesteps)
                thresholds = thresholds.reshape(-1, 1, 1, 1, 1).to(device)
                masks = (1 - noise_mask.repeat(len(timesteps), 1, 1, 1, 1).to(device)) > thresholds

        if isinstance(scheduler, dict):
            sample_scheduler = copy.deepcopy(scheduler["sample_scheduler"])
            timesteps = scheduler["timesteps"]
        else:
            sample_scheduler, timesteps, _, _ = get_scheduler(scheduler, total_steps, start_step, end_step, shift, device, transformer.dim, flowedit_args, denoise_strength, sigmas=sigmas)

        # sample videos
        latent = noise

        if offloaded:
            # Load weights
            if transformer.patched_linear and gguf_reader is None:
                load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args)
            elif gguf_reader is not None:  # handle GGUF
                load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args)
            # blockswap init
            init_blockswap(transformer, block_swap_args, model)

        # Use the appropriate prompt for this section
        if len(text_embeds["prompt_embeds"]) > 1:
            prompt_index = min(iteration_count, len(text_embeds["prompt_embeds"]) - 1)
            positive = [text_embeds["prompt_embeds"][prompt_index]]
            log.info(f"Using prompt index: {prompt_index}")
        else:
            positive = text_embeds["prompt_embeds"]

        # uni3c slices
        uni3c_data_input = None
        if uni3c_embeds is not None:
            render_latent = uni3c_embeds["render_latent"][:, :, start_latent:end_latent].to(device)
            if render_latent.shape[2] < noise.shape[1]:
                render_latent = torch.nn.functional.interpolate(render_latent, size=(noise.shape[1], noise.shape[2], noise.shape[3]), mode='trilinear', align_corners=False)
            uni3c_data_input = {"render_latent": render_latent}
            for k in uni3c_data:
                if k != "render_latent":
                    uni3c_data_input[k] = uni3c_data[k]

        mm.soft_empty_cache()
        gc.collect()

        # inner WanAnimate sampling loop
        sampling_pbar = tqdm(total=len(timesteps), desc=f"Frames {start}-{end}", position=0, leave=True)
        for i in range(len(timesteps)):
            timestep = timesteps[i]
            latent_model_input = latent.to(device)
            noise_pred, _, self.cache_state = predict_with_cfg(
                latent_model_input, cfg[min(i, len(timesteps)-1)], positive, text_embeds["negative_prompt_embeds"],
                timestep, i, cache_state=self.cache_state, image_cond=image_cond_in, clip_fea=clip_fea, wananim_face_pixels=face_images_in,
                wananim_pose_latents=pose_input_slice, uni3c_data=uni3c_data_input,
            )
            if callback is not None:
                callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * timestep.to(device) / 1000).detach().permute(1, 0, 2, 3)
                callback(step_iteration_count, callback_latent, None, estimated_iterations * len(timesteps))
                del callback_latent
            sampling_pbar.update(1)
            step_iteration_count += 1

            if use_tsr:
                noise_pred = temporal_score_rescaling(noise_pred, latent, timestep, tsr_k, tsr_sigma)

            latent = sample_scheduler.step(noise_pred.unsqueeze(0), timestep, latent.unsqueeze(0).to(noise_pred.device), **scheduler_step_args)[0].squeeze(0)
            del noise_pred, latent_model_input, timestep

            # differential diffusion inpaint
            if masks is not None:
                if i < len(timesteps) - 1:
                    image_latent = add_noise(original_image.to(device), noise.to(device), timesteps[i+1])
                    mask = masks[i].to(latent)
                    latent = image_latent * mask + latent * (1 - mask)

        del noise
        if offload:
            offload_transformer(transformer, remove_lora=False)
            offloaded = True
        vae.to(device)
        videos = vae.decode(latent[:, 1:].unsqueeze(0).to(device, vae.dtype), device=device, tiled=tiled_vae, pbar=False)[0].cpu()
        del latent
        if start != 0:
            videos = videos[:, refert_num:]
        sampling_pbar.close()

        # optional color correction
        if colormatch != "disabled":
            videos = videos.permute(1, 2, 3, 0).float().numpy()
            from color_matcher import ColorMatcher
            cm = ColorMatcher()
            cm_result_list = []
            for img in videos:
                cm_result = cm.transfer(src=img, ref=ref_images.permute(1, 2, 3, 0).squeeze(0).cpu().float().numpy(), method=colormatch)
                cm_result_list.append(torch.from_numpy(cm_result).to(vae.dtype))
            videos = torch.stack(cm_result_list, dim=0).permute(3, 0, 1, 2)
            del cm_result_list

        current_ref_images = videos[:, -refert_num:].clone().detach()

        # optionally save generated samples to disk
        if output_path:
            video_np = videos.clamp(-1.0, 1.0).add(1.0).div(2.0).mul(255).cpu().float().numpy().transpose(1, 2, 3, 0).astype('uint8')
            # the overlapping reference frames were already dropped above, so every remaining frame is saved
            log.info(f"Saving {video_np.shape[0]} generated frames to {output_path}")
            for i in range(video_np.shape[0]):
                im = Image.fromarray(video_np[i])
                im.save(os.path.join(output_path, f"frame_{img_counter:05d}.png"))
                img_counter += 1
        else:
            gen_video_list.append(videos)
        del videos

        iteration_count += 1
        start += frame_window_size - refert_num
        end += frame_window_size - refert_num
        start_latent += latent_window_size - ((refert_num - 1) // 4 + 1)
        end_latent += latent_window_size - ((refert_num - 1) // 4 + 1)

    if not output_path:
        gen_video_samples = torch.cat(gen_video_list, dim=1)
    else:
        gen_video_samples = torch.zeros(3, 1, 64, 64)  # dummy output

    if force_offload:
        vae.to(offload_device)
        if not model["auto_cpu_offload"]:
            offload_transformer(transformer)
    try:
        print_memory(device)
        torch.cuda.reset_peak_memory_stats(device)
    except Exception:
        pass
    return {"video": gen_video_samples.permute(1, 2, 3, 0), "output_path": output_path},
```
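Each window re-uses `refert_num` trailing frames of the previous window as temporal reference, so it only contributes `frame_window_size - refert_num` new frames, and the frame and latent indices advance by different strides. A standalone sketch of that bookkeeping, using the same formulas as the loop above:

```python=
# Standalone sketch of the sliding-window index arithmetic in the wananimate loop.
def plan_windows(total_frames=161, frame_window_size=77, refert_num=1):
    real_clip_len = frame_window_size - refert_num
    last_clip = (total_frames - refert_num) % real_clip_len
    extra = 0 if last_clip == 0 else real_clip_len - last_clip
    target_len = total_frames + extra            # inputs are pingpong-padded to this length
    estimated_iterations = target_len // real_clip_len

    latent_window_size = (frame_window_size - 1) // 4 + 1
    windows = []
    start, end = 0, frame_window_size
    start_lat, end_lat = 0, latent_window_size
    while start + refert_num < total_frames:     # same condition as the `break` above
        windows.append((start, end, start_lat, end_lat))
        start += real_clip_len
        end += real_clip_len
        start_lat += latent_window_size - ((refert_num - 1) // 4 + 1)
        end_lat += latent_window_size - ((refert_num - 1) // 4 + 1)
    return estimated_iterations, windows

its, wins = plan_windows()
print(its, len(wins))  # 3 3: three 77-frame windows cover 161 frames with 1 overlap frame each
print(wins[1])         # (76, 153, 19, 39)
```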
## SuperWanVideoAnimateEmbeds
```python=
class SuperWanVideoAnimateEmbeds:
    """
    Super node that chains multiple WanAnimate operations:
    WanAnimatePreprocessor -> PoseAndFaceDetection -> DrawViTPose -> WanVideoAnimateEmbeds
    This avoids intermediate caching by executing all operations in a single node.
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # WanAnimatePreprocessor inputs
                "images": ("IMAGE",),
                # PoseAndFaceDetection inputs
                "model": ("POSEMODEL",),
                # DrawViTPose inputs
                "retarget_padding": ("INT", {"default": 16, "min": 0, "max": 512, "step": 1, "tooltip": "When > 0, the retargeted pose image is padded and resized to the target size"}),
                "body_stick_width": ("INT", {"default": -1, "min": -1, "max": 20, "step": 1, "tooltip": "Width of the body sticks. Set to 0 to disable body drawing, -1 for auto"}),
                "hand_stick_width": ("INT", {"default": -1, "min": -1, "max": 20, "step": 1, "tooltip": "Width of the hand sticks. Set to 0 to disable hand drawing, -1 for auto"}),
                "draw_head": ("BOOLEAN", {"default": True, "tooltip": "Whether to draw head keypoints"}),
                # WanVideoAnimateEmbeds inputs
                "vae": ("WANVAE",),
                "force_offload": ("BOOLEAN", {"default": True}),
                "frame_window_size": ("INT", {"default": 77, "min": 1, "max": 1000, "step": 1, "tooltip": "Number of frames to use for temporal attention window"}),
                "colormatch": (
                    [
                        'disabled',
                        'mkl',
                        'hm',
                        'reinhard',
                        'mvgd',
                        'hm-mvgd-hm',
                        'hm-mkl-hm',
                    ], {
                        "default": 'disabled', "tooltip": "Color matching method to use between the windows"
                    },
                ),
                "pose_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the pose"}),
                "face_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the face"}),
            },
            "optional": {
                # WanAnimatePreprocessor optional inputs
                "is_mixing_mode": ("BOOLEAN", {"default": False, "tooltip": "Enable mixing mode"}),
                "target_width": ("INT", {"default": 0, "min": 0, "max": 2048, "step": 8, "tooltip": "Target width (0 = auto)"}),
                "target_height": ("INT", {"default": 0, "min": 0, "max": 2048, "step": 8, "tooltip": "Target height (0 = auto)"}),
                "force_rate": ("INT", {"default": 16, "min": 1, "max": 60, "step": 1, "tooltip": "Frame rate"}),
                "padding": ("BOOLEAN", {"default": True, "tooltip": "Enable padding"}),
                # PoseAndFaceDetection optional inputs
                "retarget_image": ("IMAGE", {"default": None, "tooltip": "Optional reference image for pose retargeting"}),
                # WanVideoAnimateEmbeds optional inputs
                "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}),
                "ref_images": ("IMAGE", {"tooltip": "Reference image(s) to encode"}),
                "bg_images": ("IMAGE", {"tooltip": "Background images"}),
                "mask": ("MASK", {"tooltip": "mask"}),
                "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "INT", "INT")
    RETURN_NAMES = ("image_embeds", "num_padded_frames", "frame_count")
    FUNCTION = "process"
    CATEGORY = "custom"
    DESCRIPTION = "Processes images through WanAnimatePreprocessor -> PoseAndFaceDetection -> DrawViTPose -> WanVideoAnimateEmbeds in a single node to avoid intermediate caching."

    def process(self, model, images, retarget_padding, body_stick_width, hand_stick_width, draw_head,
                vae, force_offload, frame_window_size, colormatch, pose_strength, face_strength,
                is_mixing_mode=False, target_width=0, target_height=0, force_rate=16, padding=True,
                retarget_image=None, clip_embeds=None, ref_images=None, bg_images=None, mask=None, tiled_vae=False):
        """
        Execute the full pipeline without intermediate caching.
        """
        log_container_memory_usage("[SuperWanVideoAnimateEmbeds - before process]")

        ### Step 0: WanAnimatePreprocessor
        if "WanAnimatePreprocessor" in nodes.NODE_CLASS_MAPPINGS:
            WanAnimatePreprocessorClass = nodes.NODE_CLASS_MAPPINGS["WanAnimatePreprocessor"]
            preprocessor_obj = WanAnimatePreprocessorClass()
            processed_images, num_padded_frames, width, height = preprocessor_obj.execute(
                images, is_mixing_mode=is_mixing_mode, target_width=target_width,
                target_height=target_height, force_rate=force_rate, padding=padding
            )
            images.set_(torch.empty(0))
            # del preprocessor_obj
        else:
            raise ImportError("Could not find WanAnimatePreprocessor node")
        log_container_memory_usage("[SuperWanVideoAnimateEmbeds - after Step 0: WanAnimatePreprocessor]")

        # Calculate num_frames from the processed images (which already include the padded frames)
        num_frames = processed_images.shape[0]
        frame_count = num_frames

        ### Step 1: PoseAndFaceDetection (using processed images and dimensions from WanAnimatePreprocessor)
        if "PFPoseAndFaceDetection" in nodes.NODE_CLASS_MAPPINGS:
            PFPoseAndFaceDetectionClass = nodes.NODE_CLASS_MAPPINGS["PFPoseAndFaceDetection"]
            pose_detection_obj = PFPoseAndFaceDetectionClass()
            pose_data, face_images, key_frame_body_points, bboxes, face_bboxes = pose_detection_obj.process(
                model, processed_images, width, height, retarget_image
            )
            # del pose_detection_obj
            # processed_images.set_(torch.empty(0))
            # del processed_images
        else:
            raise ImportError("Could not find PFPoseAndFaceDetection node")
        log_container_memory_usage("[SuperWanVideoAnimateEmbeds - after Step 1: PoseAndFaceDetection]")

        ### Step 2: DrawViTPose
        if "PFDrawViTPose" in nodes.NODE_CLASS_MAPPINGS:
            PFDrawViTPoseClass = nodes.NODE_CLASS_MAPPINGS["PFDrawViTPose"]
            draw_pose_obj = PFDrawViTPoseClass()
            pose_images = draw_pose_obj.process(
                pose_data, width, height, body_stick_width, hand_stick_width, draw_head, retarget_padding
            )[0]
            # del draw_pose_obj
        else:
            raise ImportError("Could not find PFDrawViTPose node")
        mm.soft_empty_cache()
        gc.collect()
        log_container_memory_usage("[SuperWanVideoAnimateEmbeds - after Step 2: DrawViTPose]")

        ### Step 3: WanVideoAnimateEmbeds
        if "PFWanVideoAnimateEmbeds" in nodes.NODE_CLASS_MAPPINGS:
            WanVideoAnimateEmbedsClass = nodes.NODE_CLASS_MAPPINGS["PFWanVideoAnimateEmbeds"]
            wan_video_obj = WanVideoAnimateEmbedsClass()
            image_embeds = wan_video_obj.process(
                vae, width, height, num_frames, force_offload, frame_window_size, colormatch, pose_strength, face_strength,
                ref_images=ref_images, pose_images=pose_images, face_images=face_images, clip_embeds=clip_embeds,
                tiled_vae=tiled_vae, bg_images=bg_images, mask=mask
            )[0]
            # del wan_video_obj
            # face_images.set_(torch.empty(0, device=face_images.device))
            # pose_images.set_(torch.empty(0, device=pose_images.device))
            # del face_images, pose_images
        else:
            raise ImportError("Could not find PFWanVideoAnimateEmbeds node")
        log_container_memory_usage("[SuperWanVideoAnimateEmbeds - after Step 3: WanVideoAnimateEmbeds]")

        mm.soft_empty_cache()
        gc.collect()
        log_container_memory_usage("[SuperWanVideoAnimateEmbeds - after gc]")
        return (image_embeds, num_padded_frames, frame_count)
```
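Outside of a graph, the super node can also be driven directly once its upstream models are loaded. A minimal sketch, assuming the class is registered under the key `SuperWanVideoAnimateEmbeds` in `NODE_CLASS_MAPPINGS` (the actual mapping name may differ) and that `pose_model` (POSEMODEL) and `vae` (WANVAE) are already-loaded objects coming from their respective loader nodes:

```python=
# Hypothetical programmatic use of the super node inside a ComfyUI environment.
import nodes  # ComfyUI node registry

def build_animate_embeds(images, pose_model, vae, ref_images=None):
    # Assumed registration key; adjust to the name used in this package's mappings.
    super_cls = nodes.NODE_CLASS_MAPPINGS["SuperWanVideoAnimateEmbeds"]
    node = super_cls()
    image_embeds, num_padded_frames, frame_count = node.process(
        pose_model, images,
        retarget_padding=16, body_stick_width=-1, hand_stick_width=-1, draw_head=True,
        vae=vae, force_offload=True, frame_window_size=77, colormatch="disabled",
        pose_strength=1.0, face_strength=1.0,
        ref_images=ref_images,
    )
    return image_embeds, num_padded_frames, frame_count
```

The returned `image_embeds` dict is the same `WANVIDIMAGE_EMBEDS` structure produced by `WanVideoAnimateEmbeds`, so it feeds straight into the `wananimate_loop` branch of `WanVideoSampler` shown above.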