## WanVideoAnimateEmbeds

```python=
class WanVideoAnimateEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "vae": ("WANVAE",),
                "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
                "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
                "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
                "force_offload": ("BOOLEAN", {"default": True}),
                "frame_window_size": ("INT", {"default": 77, "min": 1, "max": 1000, "step": 1, "tooltip": "Number of frames to use for temporal attention window"}),
                "colormatch": (
                    ['disabled', 'mkl', 'hm', 'reinhard', 'mvgd', 'hm-mvgd-hm', 'hm-mkl-hm'],
                    {"default": 'disabled', "tooltip": "Color matching method to use between the windows"},
                ),
                "pose_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the pose"}),
                "face_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the face"}),
            },
            "optional": {
                "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}),
                "ref_images": ("IMAGE", {"tooltip": "Reference image to encode"}),
                "pose_images": ("IMAGE", {"tooltip": "Pose images to encode"}),
                "face_images": ("IMAGE", {"tooltip": "Face images to encode"}),
                "bg_images": ("IMAGE", {"tooltip": "Background images"}),
                "mask": ("MASK", {"tooltip": "mask"}),
                "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "pf_custom_nodes/WanVideoWrapper"

    def process(self, vae, width, height, num_frames, force_offload, frame_window_size, colormatch, pose_strength, face_strength,
                ref_images=None, pose_images=None, face_images=None, clip_embeds=None, tiled_vae=False, bg_images=None, mask=None):
        H = height
        W = width
        lat_h = H // vae.upsampling_factor
        lat_w = W // vae.upsampling_factor

        num_refs = ref_images.shape[0] if ref_images is not None else 0
        num_frames = ((num_frames - 1) // 4) * 4 + 1
        looping = num_frames > frame_window_size
        if num_frames < frame_window_size:
            frame_window_size = num_frames

        target_shape = (16, (num_frames - 1) // 4 + 1 + num_refs, lat_h, lat_w)
        latent_window_size = (frame_window_size - 1) // 4
        if not looping:
            num_frames = num_frames + num_refs * 4
        else:
            latent_window_size = latent_window_size + 1

        mm.soft_empty_cache()
        gc.collect()

        vae.to(device)

        # Resize and rearrange the input image dimensions
        pose_latents = ref_latents = ref_latent = None
        if pose_images is not None:
            pose_images = pose_images[..., :3]
            if pose_images.shape[1] != H or pose_images.shape[2] != W:
                resized_pose_images = common_upscale(pose_images.movedim(-1, 1), W, H, "lanczos", "center").movedim(0, 1)
            else:
                resized_pose_images = pose_images.permute(3, 0, 1, 2)  # C, T, H, W
            resized_pose_images = resized_pose_images * 2 - 1
            if not looping:
                pose_latents = vae.encode([resized_pose_images.to(device, vae.dtype)], device, tiled=tiled_vae)
                pose_latents = pose_latents.to(offload_device)
                if pose_latents.shape[2] < latent_window_size:
                    print(f"WanAnimate: Padding pose latents from {pose_latents.shape} to length {latent_window_size}")
                    pad_len = latent_window_size - pose_latents.shape[2]
                    pad = torch.zeros(pose_latents.shape[0], pose_latents.shape[1], pad_len, pose_latents.shape[3], pose_latents.shape[4],
                                      device=pose_latents.device, dtype=pose_latents.dtype)
                    pose_latents = torch.cat([pose_latents, pad], dim=2)
                del resized_pose_images
            else:
                resized_pose_images = resized_pose_images.to(offload_device, dtype=vae.dtype)

        bg_latents = None
        if bg_images is not None:
            if bg_images.shape[1] != H or bg_images.shape[2] != W:
                resized_bg_images = common_upscale(bg_images.movedim(-1, 1), W, H, "lanczos", "center").movedim(0, 1)
            else:
                resized_bg_images = bg_images.permute(3, 0, 1, 2)  # C, T, H, W
            resized_bg_images = resized_bg_images[:3] * 2 - 1
        if not looping:
            if bg_images is None:
                resized_bg_images = torch.zeros(3, num_frames - num_refs, H, W, device=device, dtype=vae.dtype)
            bg_latents = vae.encode([resized_bg_images.to(device, vae.dtype)], device, tiled=tiled_vae)[0].to(offload_device)
            del resized_bg_images
        elif bg_images is not None:
            resized_bg_images = resized_bg_images.to(offload_device, dtype=vae.dtype)

        if ref_images is not None:
            if ref_images.shape[1] != H or ref_images.shape[2] != W:
                resized_ref_images = common_upscale(ref_images.movedim(-1, 1), W, H, "lanczos", "center").movedim(0, 1)
            else:
                resized_ref_images = ref_images.permute(3, 0, 1, 2)  # C, T, H, W
            resized_ref_images = resized_ref_images[:3] * 2 - 1
            ref_latent = vae.encode([resized_ref_images.to(device, vae.dtype)], device, tiled=tiled_vae)[0]

            msk = torch.zeros(4, 1, lat_h, lat_w, device=device, dtype=vae.dtype)
            msk[:, :num_refs] = 1
            import logging
            logging.info(f"WanVideoAnimateEmbeds: ref_latent shape {ref_latent.shape}")
            logging.info(f"WanVideoAnimateEmbeds: msk shape {msk.shape}")
            ref_latent_masked = torch.cat([msk, ref_latent], dim=0).to(offload_device)  # 4+C, 1, H, W

            if mask is None:
                bg_mask = torch.zeros(1, num_frames, lat_h, lat_w, device=offload_device, dtype=vae.dtype)
            else:
                bg_mask = 1 - mask[:num_frames]
                if bg_mask.shape[0] < num_frames and not looping:
                    bg_mask = torch.cat([bg_mask, bg_mask[-1:].repeat(num_frames - bg_mask.shape[0], 1, 1)], dim=0)
                bg_mask = common_upscale(bg_mask.unsqueeze(1), lat_w, lat_h, "nearest", "center").squeeze(1)
                bg_mask = bg_mask.unsqueeze(-1).permute(3, 0, 1, 2).to(offload_device, vae.dtype)  # C, T, H, W
                if bg_images is None and looping:
                    bg_mask[:, :num_refs] = 1
            bg_mask_mask_repeated = torch.repeat_interleave(bg_mask[:, 0:1], repeats=4, dim=1)  # T, C, H, W
            bg_mask = torch.cat([bg_mask_mask_repeated, bg_mask[:, 1:]], dim=1)
            bg_mask = bg_mask.view(1, bg_mask.shape[1] // 4, 4, lat_h, lat_w)  # 1, T, C, H, W
            bg_mask = bg_mask.movedim(1, 2)[0]  # C, T, H, W

            if not looping:
                bg_latents_masked = torch.cat([bg_mask[:, :bg_latents.shape[1]], bg_latents], dim=0)
                del bg_mask, bg_latents
                ref_latent = torch.cat([ref_latent_masked, bg_latents_masked], dim=1)
            else:
                ref_latent = ref_latent_masked

        if face_images is not None:
            face_images = face_images[..., :3]
            if face_images.shape[1] != 512 or face_images.shape[2] != 512:
                resized_face_images = common_upscale(face_images.movedim(-1, 1), 512, 512, "lanczos", "center").movedim(0, 1)
            else:
                resized_face_images = face_images.permute(3, 0, 1, 2)  # B, C, T, H, W
            resized_face_images = (resized_face_images * 2 - 1).unsqueeze(0)
            resized_face_images = resized_face_images.to(offload_device, dtype=vae.dtype)

        seq_len = math.ceil((target_shape[2] * target_shape[3]) / 4 * target_shape[1])

        if force_offload:
            vae.model.to(offload_device)
            mm.soft_empty_cache()
            gc.collect()

        image_embeds = {
            "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None,
            "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None,
            "max_seq_len": seq_len,
            "pose_latents": pose_latents,
"pose_images": resized_pose_images if pose_images is not None and looping else None, "bg_images": resized_bg_images if bg_images is not None and looping else None, "ref_masks": bg_mask if mask is not None and looping else None, "is_masked": mask is not None, "ref_latent": ref_latent, "ref_image": resized_ref_images if ref_images is not None else None, "face_pixels": resized_face_images if face_images is not None else None, "num_frames": num_frames, "target_shape": target_shape, "frame_window_size": frame_window_size, "lat_h": lat_h, "lat_w": lat_w, "vae": vae, "colormatch": colormatch, "looping": looping, "pose_strength": pose_strength, "face_strength": face_strength, } return (image_embeds,) ``` ## WanVideoSampler ```python= # region wananimate loop elif wananimate_loop: # calculate frame counts total_frames = num_frames refert_num = 1 real_clip_len = frame_window_size - refert_num last_clip_num = (total_frames - refert_num) % real_clip_len extra = 0 if last_clip_num == 0 else real_clip_len - last_clip_num target_len = total_frames + extra estimated_iterations = target_len // real_clip_len target_latent_len = (target_len - 1) // 4 + estimated_iterations latent_window_size = (frame_window_size - 1) // 4 + 1 from .utils import tensor_pingpong_pad ref_latent = image_embeds.get("ref_latent", None) ref_images = image_embeds.get("ref_image", None) bg_images = image_embeds.get("bg_images", None) pose_images = image_embeds.get("pose_images", None) current_ref_images = face_images = face_images_in = None if wananim_face_pixels is not None: face_images = tensor_pingpong_pad(wananim_face_pixels, target_len) log.info(f"WanAnimate: Face input {wananim_face_pixels.shape} padded to shape {face_images.shape}") if wananim_ref_masks is not None: ref_masks_in = tensor_pingpong_pad(wananim_ref_masks, target_latent_len) log.info(f"WanAnimate: Ref masks {wananim_ref_masks.shape} padded to shape {ref_masks_in.shape}") if bg_images is not None: bg_images_in = tensor_pingpong_pad(bg_images, target_len) log.info(f"WanAnimate: BG images {bg_images.shape} padded to shape {bg_images.shape}") if pose_images is not None: pose_images_in = tensor_pingpong_pad(pose_images, target_len) log.info(f"WanAnimate: Pose images {pose_images.shape} padded to shape {pose_images_in.shape}") # init variables offloaded = False colormatch = image_embeds.get("colormatch", "disabled") output_path = image_embeds.get("output_path", "") offload = image_embeds.get("force_offload", False) lat_h, lat_w = noise.shape[2], noise.shape[3] start = start_latent = img_counter = step_iteration_count = iteration_count = 0 end = frame_window_size end_latent = latent_window_size callback = prepare_callback(patcher, estimated_iterations) log.info(f"Sampling {total_frames} frames in {estimated_iterations} windows, at {latent.shape[3]*vae_upscale_factor}x{latent.shape[2]*vae_upscale_factor} with {steps} steps") # outer WanAnimate loop gen_video_list = [] while True: if start + refert_num >= total_frames: break mm.soft_empty_cache() mask_reft_len = 0 if start == 0 else refert_num self.cache_state = [None, None] noise = torch.randn(16, latent_window_size + 1, lat_h, lat_w, dtype=torch.float32, device=torch.device("cpu"), generator=seed_g).to(device) seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1]) if current_ref_images is not None or bg_images is not None or ref_latent is not None: if offload: offload_transformer(transformer, remove_lora=False) offloaded = True vae.to(device) if wananim_ref_masks is not None: msk = ref_masks_in[:, 
                msk = ref_masks_in[:, start_latent:end_latent].to(device, dtype)
            else:
                msk = torch.zeros(4, latent_window_size, lat_h, lat_w, device=device, dtype=dtype)
            if bg_images is not None:
                bg_image_slice = bg_images_in[:, start:end].to(device)
            else:
                bg_image_slice = torch.zeros(3, frame_window_size - refert_num, lat_h * 8, lat_w * 8, device=device, dtype=vae.dtype)

            if mask_reft_len == 0:
                temporal_ref_latents = vae.encode([bg_image_slice], device, tiled=tiled_vae)[0]
            else:
                concatenated = torch.cat([current_ref_images.to(device, dtype=vae.dtype), bg_image_slice[:, mask_reft_len:]], dim=1)
                temporal_ref_latents = vae.encode([concatenated.to(device, vae.dtype)], device, tiled=tiled_vae, pbar=False)[0]
                msk[:, :mask_reft_len] = 1

            if msk.shape[1] != temporal_ref_latents.shape[1]:
                if temporal_ref_latents.shape[1] < msk.shape[1]:
                    pad_len = msk.shape[1] - temporal_ref_latents.shape[1]
                    pad_tensor = temporal_ref_latents[:, -1:].repeat(1, pad_len, 1, 1)
                    temporal_ref_latents = torch.cat([temporal_ref_latents, pad_tensor], dim=1)
                else:
                    temporal_ref_latents = temporal_ref_latents[:, :msk.shape[1]]

            if ref_latent is not None:
                temporal_ref_latents = torch.cat([msk, temporal_ref_latents], dim=0)  # 4+C, T, H, W
                image_cond_in = torch.cat([ref_latent.to(device), temporal_ref_latents], dim=1)  # 4+C, T+trefs, H, W
                del temporal_ref_latents, msk, bg_image_slice
            else:
                image_cond_in = torch.cat([torch.tile(torch.zeros_like(noise[:1]), [4, 1, 1, 1]), torch.zeros_like(noise)], dim=0).to(device)
        else:
            image_cond_in = torch.cat([torch.tile(torch.zeros_like(noise[:1]), [4, 1, 1, 1]), torch.zeros_like(noise)], dim=0).to(device)

        pose_input_slice = None
        if pose_images is not None:
            vae.to(device)
            pose_image_slice = pose_images_in[:, start:end].to(device)
            pose_input_slice = vae.encode([pose_image_slice], device, tiled=tiled_vae, pbar=False).to(dtype)
            vae.to(offload_device)

        if wananim_face_pixels is None and wananim_ref_masks is not None:
            face_images_in = torch.zeros(1, 3, frame_window_size, 512, 512, device=device, dtype=torch.float32)
        elif wananim_face_pixels is not None:
            face_images_in = face_images[:, :, start:end].to(device, torch.float32) if face_images is not None else None

        if samples is not None:
            input_samples = samples["samples"]
            if input_samples is not None:
                input_samples = input_samples.squeeze(0).to(noise)
                # Check if we have enough frames in input_samples
                if latent_end_idx > input_samples.shape[1]:
                    # We need more frames than available - pad the input_samples at the end
                    pad_length = latent_end_idx - input_samples.shape[1]
                    last_frame = input_samples[:, -1:].repeat(1, pad_length, 1, 1)
                    input_samples = torch.cat([input_samples, last_frame], dim=1)
                input_samples = input_samples[:, latent_start_idx:latent_end_idx]
                if noise_mask is not None:
                    original_image = input_samples.to(device)
                assert input_samples.shape[1] == noise.shape[1], f"Slice mismatch: {input_samples.shape[1]} vs {noise.shape[1]}"
                if add_noise_to_samples:
                    latent_timestep = timesteps[0]
                    noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples
                else:
                    noise = input_samples

            # diff diff prep
            noise_mask = samples.get("noise_mask", None)
            if noise_mask is not None:
                if len(noise_mask.shape) == 4:
                    noise_mask = noise_mask.squeeze(1)
                if noise_mask.shape[0] < noise.shape[1]:
                    noise_mask = noise_mask.repeat(noise.shape[1] // noise_mask.shape[0], 1, 1)
                else:
                    noise_mask = noise_mask[start_latent:end_latent]
                noise_mask = torch.nn.functional.interpolate(
                    noise_mask.unsqueeze(0).unsqueeze(0),  # Add batch and channel dims [1,1,T,H,W]
                    size=(noise.shape[1], noise.shape[2], noise.shape[3]),
                    mode='trilinear',
                    align_corners=False
                ).repeat(1, noise.shape[0], 1, 1, 1)

                thresholds = torch.arange(len(timesteps), dtype=original_image.dtype) / len(timesteps)
                thresholds = thresholds.reshape(-1, 1, 1, 1, 1).to(device)
                masks = (1 - noise_mask.repeat(len(timesteps), 1, 1, 1, 1).to(device)) > thresholds

        if isinstance(scheduler, dict):
            sample_scheduler = copy.deepcopy(scheduler["sample_scheduler"])
            timesteps = scheduler["timesteps"]
        else:
            sample_scheduler, timesteps, _, _ = get_scheduler(
                scheduler, total_steps, start_step, end_step, shift, device,
                transformer.dim, flowedit_args, denoise_strength, sigmas=sigmas)

        # sample videos
        latent = noise

        if offloaded:
            # Load weights
            if transformer.patched_linear and gguf_reader is None:
                load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype,
                             transformer_load_device=device, block_swap_args=block_swap_args)
            elif gguf_reader is not None:  # handle GGUF
                load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device,
                             patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args)
            # blockswap init
            init_blockswap(transformer, block_swap_args, model)

        # Use the appropriate prompt for this section
        if len(text_embeds["prompt_embeds"]) > 1:
            prompt_index = min(iteration_count, len(text_embeds["prompt_embeds"]) - 1)
            positive = [text_embeds["prompt_embeds"][prompt_index]]
            log.info(f"Using prompt index: {prompt_index}")
        else:
            positive = text_embeds["prompt_embeds"]

        # uni3c slices
        uni3c_data_input = None
        if uni3c_embeds is not None:
            render_latent = uni3c_embeds["render_latent"][:, :, start_latent:end_latent].to(device)
            if render_latent.shape[2] < noise.shape[1]:
                render_latent = torch.nn.functional.interpolate(
                    render_latent, size=(noise.shape[1], noise.shape[2], noise.shape[3]),
                    mode='trilinear', align_corners=False)
            uni3c_data_input = {"render_latent": render_latent}
            for k in uni3c_data:
                if k != "render_latent":
                    uni3c_data_input[k] = uni3c_data[k]

        mm.soft_empty_cache()
        gc.collect()

        # inner WanAnimate sampling loop
        sampling_pbar = tqdm(total=len(timesteps), desc=f"Frames {start}-{end}", position=0, leave=True)
        for i in range(len(timesteps)):
            timestep = timesteps[i]
            latent_model_input = latent.to(device)
            noise_pred, _, self.cache_state = predict_with_cfg(
                latent_model_input, cfg[min(i, len(timesteps)-1)], positive, text_embeds["negative_prompt_embeds"],
                timestep, i, cache_state=self.cache_state, image_cond=image_cond_in, clip_fea=clip_fea,
                wananim_face_pixels=face_images_in, wananim_pose_latents=pose_input_slice, uni3c_data=uni3c_data_input,
            )

            if callback is not None:
                callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * timestep.to(device) / 1000).detach().permute(1, 0, 2, 3)
                callback(step_iteration_count, callback_latent, None, estimated_iterations * len(timesteps))
                del callback_latent
            sampling_pbar.update(1)
            step_iteration_count += 1

            if use_tsr:
                noise_pred = temporal_score_rescaling(noise_pred, latent, timestep, tsr_k, tsr_sigma)

            latent = sample_scheduler.step(noise_pred.unsqueeze(0), timestep, latent.unsqueeze(0).to(noise_pred.device), **scheduler_step_args)[0].squeeze(0)
            del noise_pred, latent_model_input, timestep

            # differential diffusion inpaint
            if masks is not None:
                if i < len(timesteps) - 1:
                    image_latent = add_noise(original_image.to(device), noise.to(device), timesteps[i + 1])
                    mask = masks[i].to(latent)
                    latent = image_latent * mask + latent * (1 - mask)

        del noise
        if offload:
            offload_transformer(transformer, remove_lora=False)
            offloaded = True
        vae.to(device)
        videos = vae.decode(latent[:, 1:].unsqueeze(0).to(device, vae.dtype), device=device, tiled=tiled_vae, pbar=False)[0].cpu()
        del latent
        if start != 0:
            videos = videos[:, refert_num:]
        sampling_pbar.close()

        # optional color correction
        if colormatch != "disabled":
            videos = videos.permute(1, 2, 3, 0).float().numpy()
            from color_matcher import ColorMatcher
            cm = ColorMatcher()
            cm_result_list = []
            for img in videos:
                cm_result = cm.transfer(src=img, ref=ref_images.permute(1, 2, 3, 0).squeeze(0).cpu().float().numpy(), method=colormatch)
                cm_result_list.append(torch.from_numpy(cm_result).to(vae.dtype))
            videos = torch.stack(cm_result_list, dim=0).permute(3, 0, 1, 2)
            del cm_result_list

        current_ref_images = videos[:, -refert_num:].clone().detach()

        # optionally save generated samples to disk
        if output_path:
            video_np = videos.clamp(-1.0, 1.0).add(1.0).div(2.0).mul(255).cpu().float().numpy().transpose(1, 2, 3, 0).astype('uint8')
            num_frames_to_save = video_np.shape[0] if is_first_clip else video_np.shape[0] - cur_motion_frames_num
            log.info(f"Saving {num_frames_to_save} generated frames to {output_path}")
            start_idx = 0 if is_first_clip else cur_motion_frames_num
            for i in range(start_idx, video_np.shape[0]):
                im = Image.fromarray(video_np[i])
                im.save(os.path.join(output_path, f"frame_{img_counter:05d}.png"))
                img_counter += 1
        else:
            gen_video_list.append(videos)
        del videos

        iteration_count += 1
        start += frame_window_size - refert_num
        end += frame_window_size - refert_num
        start_latent += latent_window_size - ((refert_num - 1) // 4 + 1)
        end_latent += latent_window_size - ((refert_num - 1) // 4 + 1)

    if not output_path:
        gen_video_samples = torch.cat(gen_video_list, dim=1)
    else:
        gen_video_samples = torch.zeros(3, 1, 64, 64)  # dummy output

    if force_offload:
        vae.to(offload_device)
        if not model["auto_cpu_offload"]:
            offload_transformer(transformer)
    try:
        print_memory(device)
        torch.cuda.reset_peak_memory_stats(device)
    except:
        pass

    return {"video": gen_video_samples.permute(1, 2, 3, 0), "output_path": output_path},
```

## SuperWanVideoAnimateEmbeds

```python=
class SuperWanVideoAnimateEmbeds:
    """
    Super node that chains multiple WanAnimate operations:
    WanAnimatePreprocessor -> PoseAndFaceDetection -> DrawViTPose -> WanVideoAnimateEmbeds

    This avoids intermediate caching by executing all operations in a single node.
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # WanAnimatePreprocessor inputs
                "images": ("IMAGE",),
                # PoseAndFaceDetection inputs
                "model": ("POSEMODEL",),
                # DrawViTPose inputs
                "retarget_padding": ("INT", {"default": 16, "min": 0, "max": 512, "step": 1, "tooltip": "When > 0, the retargeted pose image is padded and resized to the target size"}),
                "body_stick_width": ("INT", {"default": -1, "min": -1, "max": 20, "step": 1, "tooltip": "Width of the body sticks. Set to 0 to disable body drawing, -1 for auto"}),
                "hand_stick_width": ("INT", {"default": -1, "min": -1, "max": 20, "step": 1, "tooltip": "Width of the hand sticks. Set to 0 to disable hand drawing, -1 for auto"}),
                "draw_head": ("BOOLEAN", {"default": True, "tooltip": "Whether to draw head keypoints"}),
                # WanVideoAnimateEmbeds inputs
                "vae": ("WANVAE",),
                "force_offload": ("BOOLEAN", {"default": True}),
                "frame_window_size": ("INT", {"default": 77, "min": 1, "max": 1000, "step": 1, "tooltip": "Number of frames to use for temporal attention window"}),
                "colormatch": (
                    ['disabled', 'mkl', 'hm', 'reinhard', 'mvgd', 'hm-mvgd-hm', 'hm-mkl-hm'],
                    {"default": 'disabled', "tooltip": "Color matching method to use between the windows"},
                ),
                "pose_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the pose"}),
                "face_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the face"}),
            },
            "optional": {
                # WanAnimatePreprocessor optional inputs
                "is_mixing_mode": ("BOOLEAN", {"default": False, "tooltip": "Enable mixing mode"}),
                "target_width": ("INT", {"default": 0, "min": 0, "max": 2048, "step": 8, "tooltip": "Target width (0 = auto)"}),
                "target_height": ("INT", {"default": 0, "min": 0, "max": 2048, "step": 8, "tooltip": "Target height (0 = auto)"}),
                "force_rate": ("INT", {"default": 16, "min": 1, "max": 60, "step": 1, "tooltip": "Frame rate"}),
                "padding": ("BOOLEAN", {"default": True, "tooltip": "Enable padding"}),
                # PoseAndFaceDetection optional inputs
                "retarget_image": ("IMAGE", {"default": None, "tooltip": "Optional reference image for pose retargeting"}),
                # WanVideoAnimateEmbeds optional inputs
                "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}),
                "ref_images": ("IMAGE", {"tooltip": "Reference image to encode"}),
                "bg_images": ("IMAGE", {"tooltip": "Background images"}),
                "mask": ("MASK", {"tooltip": "mask"}),
                "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "INT", "INT")
    RETURN_NAMES = ("image_embeds", "num_padded_frames", "frame_count")
    FUNCTION = "process"
    CATEGORY = "custom"
    DESCRIPTION = "Processes images through WanAnimatePreprocessor -> PoseAndFaceDetection -> DrawViTPose -> WanVideoAnimateEmbeds in a single node to avoid intermediate caching."

    def process(self, model, images, retarget_padding, body_stick_width, hand_stick_width, draw_head,
                vae, force_offload, frame_window_size, colormatch, pose_strength, face_strength,
                is_mixing_mode=False, target_width=0, target_height=0, force_rate=16, padding=True,
                retarget_image=None, clip_embeds=None, ref_images=None, bg_images=None, mask=None, tiled_vae=False):
        """
        Execute the full pipeline without intermediate caching.
        """
""" log_container_memory_usage("[SuperWanVideoAnimateEmbeds - before process]") ### Step 0: WanAnimatePreprocessor if "WanAnimatePreprocessor" in nodes.NODE_CLASS_MAPPINGS: WanAnimatePreprocessorClass = nodes.NODE_CLASS_MAPPINGS["WanAnimatePreprocessor"] preprocessor_obj = WanAnimatePreprocessorClass() processed_images, num_padded_frames, width, height = preprocessor_obj.execute( images, is_mixing_mode=is_mixing_mode, target_width=target_width, target_height=target_height, force_rate=force_rate, padding=padding ) images.set_(torch.empty(0)) # del preprocessor_obj else: raise ImportError("Could not find WanAnimatePreprocessor node") log_container_memory_usage("[SuperWanVideoAnimateEmbeds - after Step 0: WanAnimatePreprocessor]") # Calculate num_frames from processed images (which already includes padded frames) num_frames = processed_images.shape[0] frame_count = num_frames ### Step 1: PoseAndFaceDetection (using processed images and dimensions from WanAnimatePreprocessor) if "PFPoseAndFaceDetection" in nodes.NODE_CLASS_MAPPINGS: PFPoseAndFaceDetectionClass = nodes.NODE_CLASS_MAPPINGS["PFPoseAndFaceDetection"] pose_detection_obj = PFPoseAndFaceDetectionClass() pose_data, face_images, key_frame_body_points, bboxes, face_bboxes = pose_detection_obj.process( model, processed_images, width, height, retarget_image ) # del pose_detection_obj # processed_images.set_(torch.empty(0)) # del processed_images else: raise ImportError("Could not find PFPoseAndFaceDetection node") log_container_memory_usage("[SuperWanVideoAnimateEmbeds - after Step 1: PoseAndFaceDetection]") ### Step 2: DrawViTPose if "PFDrawViTPose" in nodes.NODE_CLASS_MAPPINGS: PFDrawViTPoseClass = nodes.NODE_CLASS_MAPPINGS["PFDrawViTPose"] draw_pose_obj = PFDrawViTPoseClass() pose_images = draw_pose_obj.process( pose_data, width, height, body_stick_width, hand_stick_width, draw_head, retarget_padding )[0] # del draw_pose_obj else: raise ImportError("Could not find PFDrawViTPose node") mm.soft_empty_cache() gc.collect() log_container_memory_usage("[SuperWanVideoAnimateEmbeds - after Step 2: DrawViTPose]") ### Step 3: WanVideoAnimateEmbeds if "PFWanVideoAnimateEmbeds" in nodes.NODE_CLASS_MAPPINGS: WanVideoAnimateEmbedsClass = nodes.NODE_CLASS_MAPPINGS["PFWanVideoAnimateEmbeds"] wan_video_obj = WanVideoAnimateEmbedsClass() image_embeds = wan_video_obj.process( vae, width, height, num_frames, force_offload, frame_window_size, colormatch, pose_strength, face_strength, ref_images=ref_images, pose_images=pose_images, face_images=face_images, clip_embeds=clip_embeds, tiled_vae=tiled_vae, bg_images=bg_images, mask=mask )[0] # del wan_video_obj # face_images.set_(torch.empty(0, device=face_images.device)) # pose_images.set_(torch.empty(0, device=pose_images.device)) # del face_images, pose_images else: raise ImportError("Could not find WanVideoAnimateEmbeds node") log_container_memory_usage("[SuperWanVideoAnimateEmbeds - after Step 3: WanVideoAnimateEmbeds]") mm.soft_empty_cache() gc.collect() log_container_memory_usage("[SuperWanVideoAnimateEmbeds - after gc]") return (image_embeds, num_padded_frames, frame_count) ```