Compare commits

...

25 Commits

Author SHA1 Message Date
Jedrzej Kosinski 386e854aab Merge branch 'master' into flipflop-stream 2025-10-28 15:08:27 -07:00
Jedrzej Kosinski 61133af772 Add '--flipflop-offload' startup argument 2025-10-13 21:10:44 -07:00
Jedrzej Kosinski 586a8de8da Merge branch 'master' into flipflop-stream 2025-10-13 21:04:37 -07:00
Jedrzej Kosinski 5329180fce Made flipflop consider partial_unload, partial_offload, and add flip+flop to mem counters 2025-10-03 16:21:01 -07:00
Jedrzej Kosinski 0fdd327c2f Merge branch 'master' into flipflop-stream 2025-10-03 14:32:56 -07:00
Jedrzej Kosinski ee01002e63 Add flipflop support to (base) WAN, fix issue with applying loras to flipflop weights being done on CPU instead of GPU, left some timing functions as the lora application time could use some reduction 2025-10-02 22:02:50 -07:00
Jedrzej Kosinski 831c3cf05e Add a temporary workaround for odd amount of blocks not producing expected results 2025-10-02 20:29:11 -07:00
Jedrzej Kosinski 0d8e8abd90 Default to smaller blocks getting flipflopped first 2025-10-02 18:00:21 -07:00
Jedrzej Kosinski d5001ed90e Make flux support flipflop 2025-10-02 17:53:22 -07:00
Jedrzej Kosinski 8d7b22b720 Fixed FlipFlopModule.execute_blocks having hardcoded strings from Qwen 2025-10-02 17:49:43 -07:00
Jedrzej Kosinski 6d3ec9fcf3 Simplified flipflop setup by adding FlipFlopModule.execute_blocks helper 2025-10-02 16:46:37 -07:00
Jedrzej Kosinski c4420b6a41 Change log string slightly 2025-10-02 15:34:35 -07:00
Jedrzej Kosinski a282586995 Merge branch 'master' into flipflop-stream 2025-10-02 15:03:26 -07:00
Jedrzej Kosinski 0df61b5032 Fix improper index slicing for flipflop get blocks, add extra log message 2025-10-01 21:21:36 -07:00
Jedrzej Kosinski 7c896c5567 Initial automatic support for flipflop within ModelPatcher - only Qwen Image diffusion_model uses FlipFlopModule currently 2025-10-01 20:13:50 -07:00
Jedrzej Kosinski ec156e72eb Merge branch 'master' into flipflop-stream 2025-09-30 23:08:37 -07:00
Jedrzej Kosinski 01f4512bf8 In-progress commit on making flipflop async weight streaming native, made loaded partially/loaded completely log messages have labels because having to memorize their meaning for dev work is annoying 2025-09-30 23:08:08 -07:00
Jedrzej Kosinski d0bd221495 Merge branch 'master' into flipflop-stream 2025-09-29 22:49:38 -07:00
Jedrzej Kosinski 8a8162e8da Fix percentage logic, begin adding elements to ModelPatcher to track flip flop compatibility 2025-09-29 22:49:12 -07:00
Jedrzej Kosinski ff789c8beb Merge branch 'master' into flipflop-stream 2025-09-29 16:09:51 -07:00
Jedrzej Kosinski 0e966dcf85 Merge branch 'master' into flipflop-stream 2025-09-27 21:13:26 -07:00
Jedrzej Kosinski 6b240b0bce Refactored old flip flop into a new implementation that allows for controlling the percentage of blocks getting flip flopped, converted nodes to v3 schema 2025-09-25 22:41:41 -07:00
Jedrzej Kosinski f9fbf902d5 Added missing Qwen block params, further subdivided blocks function 2025-09-25 17:49:39 -07:00
Jedrzej Kosinski f083720eb4 Refactored FlipFlopTransformer.__call__ to fully separate out actions between flip and flop 2025-09-25 16:16:51 -07:00
Jedrzej Kosinski 84e73f2aa5 Brought over flip flop prototype from contentis' fork, limiting it to only Qwen to ease the process of adapting it to be a native feature 2025-09-25 16:15:46 -07:00
10 changed files with 530 additions and 126 deletions

View File

@@ -132,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
parser.add_argument("--flipflop-offload", action="store_true", help="Use async flipflop weight offloading for supported DiT models.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")

View File

@@ -0,0 +1,200 @@
from __future__ import annotations
import torch
import copy
import comfy.model_management
class FlipFlopModule(torch.nn.Module):
def __init__(self, block_types: tuple[str, ...], enable_flipflop: bool = True):
super().__init__()
self.block_types = block_types
self.enable_flipflop = enable_flipflop
self.flipflop: dict[str, FlipFlopHolder] = {}
self.block_info: dict[str, tuple[int, int]] = {}
self.flipflop_prefixes: list[str] = []
def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], flipflop_prefixes: list[str], load_device: torch.device, offload_device: torch.device):
for block_type, (flipflop_blocks, total_blocks) in block_info.items():
if block_type in self.flipflop:
continue
self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[total_blocks-flipflop_blocks:], flipflop_blocks, total_blocks, load_device, offload_device)
self.block_info[block_type] = (flipflop_blocks, total_blocks)
self.flipflop_prefixes = flipflop_prefixes.copy()
def init_flipflop_block_copies(self, device: torch.device) -> int:
memory_freed = 0
for holder in self.flipflop.values():
memory_freed += holder.init_flipflop_block_copies(device)
return memory_freed
def clean_flipflop_holders(self):
memory_freed = 0
for block_type in list(self.flipflop.keys()):
memory_freed += self.flipflop[block_type].clean_flipflop_blocks()
del self.flipflop[block_type]
self.block_info = {}
self.flipflop_prefixes = []
return memory_freed
def get_all_blocks(self, block_type: str) -> list[torch.nn.Module]:
return getattr(self, block_type)
def get_blocks(self, block_type: str) -> torch.nn.ModuleList:
if block_type not in self.block_types:
raise ValueError(f"Block type {block_type} not found in {self.block_types}")
if block_type in self.flipflop:
return getattr(self, block_type)[:self.flipflop[block_type].i_offset]
return getattr(self, block_type)
def get_all_block_module_sizes(self, reverse_sort_by_size: bool = False) -> list[tuple[str, int]]:
'''
Returns a list of (block_type, size) sorted by size.
If reverse_sort_by_size is True, the list is sorted by size in reverse order.
'''
sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types]
sizes.sort(key=lambda x: x[1], reverse=reverse_sort_by_size)
return sizes
def get_block_module_size(self, block_type: str) -> int:
return comfy.model_management.module_size(getattr(self, block_type)[0])
def execute_blocks(self, block_type: str, func, out: torch.Tensor | tuple[torch.Tensor,...], *args, **kwargs):
# execute blocks, supporting both single and double (or higher) block types
if isinstance(out, torch.Tensor):
out = (out,)
for i, block in enumerate(self.get_blocks(block_type)):
out = func(i, block, *out, *args, **kwargs)
if isinstance(out, torch.Tensor):
out = (out,)
if block_type in self.flipflop:
holder = self.flipflop[block_type]
with holder.context() as ctx:
for i, block in enumerate(holder.blocks):
out = ctx(func, i, block, *out, *args, **kwargs)
if isinstance(out, torch.Tensor):
out = (out,)
if len(out) == 1:
out = out[0]
return out
class FlipFlopContext:
def __init__(self, holder: FlipFlopHolder):
# NOTE: there is a bug when there is an odd number of blocks to flipflop.
# Currently worked around by always making sure the count is even, but this needs a proper fix.
self.holder = holder
self.reset()
def reset(self):
self.num_blocks = len(self.holder.blocks)
self.first_flip = True
self.first_flop = True
self.last_flip = False
self.last_flop = False
def __enter__(self):
self.reset()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.holder.compute_stream.record_event(self.holder.cpy_end_event)
def do_flip(self, func, i: int, _, *args, **kwargs):
# flip
self.holder.compute_stream.wait_event(self.holder.cpy_end_event)
with torch.cuda.stream(self.holder.compute_stream):
out = func(i+self.holder.i_offset, self.holder.flip, *args, **kwargs)
self.holder.event_flip.record(self.holder.compute_stream)
# while flip executes, queue flop to copy to its next block
next_flop_i = i + 1
if next_flop_i >= self.num_blocks:
next_flop_i = next_flop_i - self.num_blocks
self.last_flip = True
if not self.first_flip:
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
if self.last_flip:
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
self.first_flip = False
return out
def do_flop(self, func, i: int, _, *args, **kwargs):
# flop
if not self.first_flop:
self.holder.compute_stream.wait_event(self.holder.cpy_end_event)
with torch.cuda.stream(self.holder.compute_stream):
out = func(i+self.holder.i_offset, self.holder.flop, *args, **kwargs)
self.holder.event_flop.record(self.holder.compute_stream)
# while flop executes, queue flip to copy to its next block
next_flip_i = i + 1
if next_flip_i >= self.num_blocks:
next_flip_i = next_flip_i - self.num_blocks
self.last_flop = True
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
if self.last_flop:
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
self.first_flop = False
return out
@torch.no_grad()
def __call__(self, func, i: int, block: torch.nn.Module, *args, **kwargs):
# flips are even indexes, flops are odd indexes
if i % 2 == 0:
return self.do_flip(func, i, block, *args, **kwargs)
else:
return self.do_flop(func, i, block, *args, **kwargs)
class FlipFlopHolder:
def __init__(self, blocks: list[torch.nn.Module], flip_amount: int, total_amount: int, load_device: torch.device, offload_device: torch.device):
self.load_device = load_device
self.offload_device = offload_device
self.blocks = blocks
self.flip_amount = flip_amount
self.total_amount = total_amount
# NOTE: used to make sure block indexes passed into block functions match expected patch indexes
self.i_offset = total_amount - flip_amount
self.block_module_size = 0
if len(self.blocks) > 0:
self.block_module_size = comfy.model_management.module_size(self.blocks[0])
self.flip: torch.nn.Module = None
self.flop: torch.nn.Module = None
self.compute_stream = torch.cuda.default_stream(self.load_device)
self.cpy_stream = torch.cuda.Stream(self.load_device)
self.event_flip = torch.cuda.Event(enable_timing=False)
self.event_flop = torch.cuda.Event(enable_timing=False)
self.cpy_end_event = torch.cuda.Event(enable_timing=False)
# INIT - is this actually needed?
self.compute_stream.record_event(self.cpy_end_event)
def _copy_state_dict(self, dst, src, cpy_start_event: torch.cuda.Event=None, cpy_end_event: torch.cuda.Event=None):
if cpy_start_event:
self.cpy_stream.wait_event(cpy_start_event)
with torch.cuda.stream(self.cpy_stream):
for k, v in src.items():
dst[k].copy_(v, non_blocking=True)
if cpy_end_event:
cpy_end_event.record(self.cpy_stream)
def context(self):
return FlipFlopContext(self)
def init_flipflop_block_copies(self, load_device: torch.device) -> int:
self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
return comfy.model_management.module_size(self.flip) + comfy.model_management.module_size(self.flop)
def clean_flipflop_blocks(self) -> int:
memory_freed = 0
memory_freed += comfy.model_management.module_size(self.flip)
memory_freed += comfy.model_management.module_size(self.flop)
del self.flip
del self.flop
self.flip = None
self.flop = None
return memory_freed
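The module above implements double-buffered ("ping-pong") weight streaming: two GPU-resident block copies alternate between computing and receiving the weights of the next offloaded block on a separate copy stream, so host-to-device transfers overlap compute. Below is a minimal, self-contained sketch of the same pattern; it is illustrative only (flipflop_forward and its per-slot events are simplified stand-ins for FlipFlopHolder/FlipFlopContext) and assumes an even block count, identical block architectures, single-tensor block I/O, and pinned CPU weights so the copies are truly asynchronous.

import copy
import torch

@torch.no_grad()
def flipflop_forward(blocks_cpu, x):
    # "flip" (slot 0) computes even-indexed blocks, "flop" (slot 1) odd ones;
    # while one slot computes block i, the other is refilled with block i+1.
    device = torch.device("cuda")
    slots = [copy.deepcopy(blocks_cpu[0]).to(device),
             copy.deepcopy(blocks_cpu[1]).to(device)]
    compute = torch.cuda.default_stream(device)
    cpy = torch.cuda.Stream(device)
    computed = [torch.cuda.Event(), torch.cuda.Event()]  # slot finished its last compute
    loaded = [torch.cuda.Event(), torch.cuda.Event()]    # slot finished its last refill
    for i in range(len(blocks_cpu)):
        s = i % 2
        compute.wait_event(loaded[s])    # this slot's weights must be in place
        with torch.cuda.stream(compute):
            x = slots[s](x)
        computed[s].record(compute)
        nxt = i + 1
        if 1 < nxt < len(blocks_cpu):    # slots already start with blocks 0 and 1
            o = nxt % 2
            cpy.wait_event(computed[o])  # never overwrite weights still being used
            with torch.cuda.stream(cpy):
                dst = slots[o].state_dict()
                for k, v in blocks_cpu[nxt].state_dict().items():
                    dst[k].copy_(v, non_blocking=True)
            loaded[o].record(cpy)
    # (the real do_flip/do_flop additionally reload blocks 0 and 1 on the last
    # iterations so the next sampling step starts from a clean state)
    return x

The alternation is what makes a single copy stream safe here: a slot is only rewritten after its own compute event has fired, and only computed after its own refill event has fired.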

View File

@@ -7,6 +7,7 @@ from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flipflop_transformer import FlipFlopModule
from .layers import (
DoubleStreamBlock,
@@ -35,13 +36,13 @@ class FluxParams:
guidance_embed: bool
class Flux(nn.Module):
class Flux(FlipFlopModule):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
super().__init__(("double_blocks", "single_blocks"))
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
@@ -89,6 +90,72 @@ class Flux(nn.Module):
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
def indiv_double_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img[:, :add.shape[1]] += add
return img, txt
def indiv_single_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
return img
def forward_orig(
self,
img: Tensor,
@@ -136,74 +203,16 @@ class Flux(nn.Module):
pe = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img[:, :add.shape[1]] += add
# execute double blocks
img, txt = self.execute_blocks("double_blocks", self.indiv_double_block_fwd, (img, txt), vec, pe, attn_mask, control, blocks_replace, transformer_options)
if img.dtype == torch.float16:
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
# execute single blocks
img = self.execute_blocks("single_blocks", self.indiv_single_block_fwd, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options)
img = img[:, txt.shape[1] :, ...]

View File

@@ -5,11 +5,13 @@ import torch.nn.functional as F
from typing import Optional, Tuple
from einops import repeat
from comfy.ldm.flipflop_transformer import FlipFlopModule
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
import comfy.patcher_extension
import comfy.ops
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -283,7 +285,7 @@ class LastLayer(nn.Module):
return x
class QwenImageTransformer2DModel(nn.Module):
class QwenImageTransformer2DModel(FlipFlopModule):
def __init__(
self,
patch_size: int = 2,
@@ -300,9 +302,9 @@ class QwenImageTransformer2DModel(nn.Module):
final_layer=True,
dtype=None,
device=None,
operations=None,
operations: comfy.ops.disable_weight_init=None,
):
super().__init__()
super().__init__(block_types=("transformer_blocks",))
self.dtype = dtype
self.patch_size = patch_size
self.in_channels = in_channels
@@ -366,6 +368,40 @@
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
def indiv_block_fwd(self, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
hidden_states[:, :add.shape[1]] += add
return hidden_states, encoder_hidden_states
def _forward(
self,
x,
@@ -433,37 +469,8 @@
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
hidden_states[:, :add.shape[1]] += add
out = (hidden_states, encoder_hidden_states)
hidden_states, encoder_hidden_states = self.execute_blocks("transformer_blocks", self.indiv_block_fwd, out, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flipflop_transformer import FlipFlopModule
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope1
import comfy.ldm.common_dit
@@ -384,7 +385,7 @@ class MLPProj(torch.nn.Module):
return clip_extra_context_tokens
class WanModel(torch.nn.Module):
class WanModel(FlipFlopModule):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
@@ -412,6 +413,7 @@ class WanModel(torch.nn.Module):
device=None,
dtype=None,
operations=None,
enable_flipflop=True,
):
r"""
Initialize the diffusion model backbone.
@@ -449,7 +451,7 @@
Epsilon value for normalization layers
"""
super().__init__()
super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop)
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
@@ -506,6 +508,18 @@
else:
self.ref_conv = None
def indiv_block_fwd(self, i, block, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
return x
def forward_orig(
self,
x,
@@ -567,16 +581,8 @@
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# execute blocks
x = self.execute_blocks("blocks", self.indiv_block_fwd, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options)
# head
x = self.head(x, e)
@@ -688,7 +694,7 @@ class VaceWanModel(WanModel):
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
# Vace
@@ -808,7 +814,7 @@ class CameraWanModel(WanModel):
else:
model_type = 't2v'
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
@@ -1211,7 +1217,7 @@ class WanModel_S2V(WanModel):
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
@@ -1511,7 +1517,7 @@ class HumoWanModel(WanModel):
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)

View File

@@ -426,7 +426,7 @@ class AnimateWanModel(WanModel):
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
self.pose_patch_embedding = operations.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype

View File

@@ -1006,6 +1006,8 @@ def force_channels_last():
#TODO
return False
def flipflop_enabled():
return args.flipflop_offload
STREAMS = {}
NUM_STREAMS = 1

View File

@@ -25,7 +25,7 @@ import logging
import math
import uuid
from typing import Callable, Optional
import time # TODO remove
import torch
import comfy.float
@@ -591,7 +591,7 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, device_final=None):
if key not in self.patches:
return
@@ -611,15 +611,103 @@
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if device_final is not None:
out_weight = out_weight.to(device_final)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
else:
if device_final is not None:
out_weight = out_weight.to(device_final)
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _load_list(self):
def supports_flipflop(self):
# flipflop requires a diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and a VRAM state that actually offloads weights
if not comfy.model_management.flipflop_enabled():
return False
if not hasattr(self.model, "diffusion_model"):
return False
if not getattr(self.model.diffusion_model, "enable_flipflop", False):
return False
if not comfy.model_management.is_nvidia():
return False
if comfy.model_management.vram_state in (comfy.model_management.VRAMState.HIGH_VRAM, comfy.model_management.VRAMState.SHARED):
return False
return True
def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]], flipflop_prefixes: list[str]):
if not self.supports_flipflop():
return
logging.info(f"setting up flipflop with {flipflop_blocks_per_type}")
self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, flipflop_prefixes, self.load_device, self.offload_device)
def init_flipflop_block_copies(self) -> int:
if not self.supports_flipflop():
return 0
return self.model.diffusion_model.init_flipflop_block_copies(self.load_device)
def clean_flipflop(self) -> int:
if not self.supports_flipflop():
return 0
return self.model.diffusion_model.clean_flipflop_holders()
def _get_existing_flipflop_prefixes(self):
if self.supports_flipflop():
return self.model.diffusion_model.flipflop_prefixes
return []
def _calc_flipflop_prefixes(self, lowvram_model_memory=0, prepare_flipflop=False):
flipflop_prefixes = []
flipflop_blocks_per_type: dict[str, tuple[int, int]] = {}
if lowvram_model_memory > 0 and self.supports_flipflop():
block_buffer = 3
valid_block_types = []
# for each block type, check if there is enough room to flipflop it
for block_info in self.model.diffusion_model.get_all_block_module_sizes(reverse_sort_by_size=True):
block_size: int = block_info[1]
if block_size * block_buffer < lowvram_model_memory:
valid_block_types.append(block_info)
# if there are candidates for flipping, see how many blocks of each type can flipflop
if len(valid_block_types) > 0:
leftover_memory = lowvram_model_memory
for block_info in valid_block_types:
block_type: str = block_info[0]
block_size: int = block_info[1]
total_blocks = len(self.model.diffusion_model.get_all_blocks(block_type))
n_fit_in_memory = int(leftover_memory // block_size)
# if all blocks of this type (or more) would fit in memory, no need to flipflop them
if n_fit_in_memory >= total_blocks:
leftover_memory -= total_blocks * block_size
continue
# if fewer blocks of this type would fit in memory than the buffer, skip this block type
if n_fit_in_memory < block_buffer:
continue
# 2 blocks' worth of VRAM may be needed for flipflop, so make sure to account for them.
flipflop_blocks = min((total_blocks - n_fit_in_memory) + 2, total_blocks)
# for now, work around the odd-number issue by making the count even
if flipflop_blocks % 2 != 0:
if flipflop_blocks == total_blocks:
flipflop_blocks -= 1
else:
flipflop_blocks += 1
flipflop_blocks_per_type[block_type] = (flipflop_blocks, total_blocks)
leftover_memory -= (total_blocks - flipflop_blocks + 2) * block_size
# if there are blocks to flipflop, need to mark their keys
for block_type, (flipflop_blocks, total_blocks) in flipflop_blocks_per_type.items():
# blocks to flipflop are at the end
for i in range(total_blocks-flipflop_blocks, total_blocks):
flipflop_prefixes.append(f"diffusion_model.{block_type}.{i}")
if prepare_flipflop and len(flipflop_blocks_per_type) > 0:
self.setup_flipflop(flipflop_blocks_per_type, flipflop_prefixes)
return flipflop_prefixes
def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False, get_existing_flipflop=False):
loading = []
if get_existing_flipflop:
flipflop_prefixes = self._get_existing_flipflop_prefixes()
else:
flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop)
for n, m in self.model.named_modules():
params = []
skip = False
@@ -630,7 +718,12 @@
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append((comfy.model_management.module_size(m), n, m, params))
flipflop = False
for prefix in flipflop_prefixes:
if n.startswith(prefix):
flipflop = True
break
loading.append((comfy.model_management.module_size(m), n, m, params, flipflop))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -639,14 +732,19 @@
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
loading = self._load_list()
lowvram_mem_counter = 0
flipflop_counter = 0
flipflop_mem_counter = 0
loading = self._load_list(lowvram_model_memory, prepare_flipflop=True)
load_completely = []
load_flipflop = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
params = x[3]
flipflop: bool = x[4]
module_mem = x[0]
lowvram_weight = False
@@ -654,10 +752,11 @@
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"):
if not full_load and hasattr(m, "comfy_cast_weights") and not flipflop:
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
lowvram_mem_counter += module_mem
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
@@ -687,7 +786,11 @@
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
if full_load or mem_counter + module_mem < lowvram_model_memory:
if flipflop:
flipflop_counter += 1
flipflop_mem_counter += module_mem
load_flipflop.append((module_mem, n, m, params))
elif full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
@@ -703,6 +806,7 @@
mem_counter += move_weight_functions(m, device_to)
# handle load completely
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
@@ -721,11 +825,36 @@
for x in load_completely:
x[2].to(device_to)
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
# handle flipflop
if len(load_flipflop) > 0:
start_time = time.perf_counter()
load_flipflop.sort(reverse=True)
for x in load_flipflop:
n = x[1]
m = x[2]
params = x[3]
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to, device_final=self.offload_device)
logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m))
end_time = time.perf_counter()
logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds")
start_time = time.perf_counter()
mem_counter += self.init_flipflop_block_copies()
end_time = time.perf_counter()
logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds")
if lowvram_counter > 0 or flipflop_counter > 0:
if flipflop_counter > 0:
logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {flipflop_mem_counter / (1024 * 1024):.2f} MB flipflop, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
else:
logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}")
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
@@ -762,6 +891,7 @@
self.eject_model()
if unpatch_weights:
self.unpatch_hooks()
self.clean_flipflop()
if self.model.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
@@ -801,8 +931,9 @@
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
memory_freed += self.clean_flipflop()
patch_counter = 0
unload_list = self._load_list()
unload_list = self._load_list(get_existing_flipflop=True)
unload_list.sort()
for unload in unload_list:
if memory_to_free < memory_freed:
@@ -811,7 +942,10 @@
n = unload[1]
m = unload[2]
params = unload[3]
flipflop: bool = unload[4]
if flipflop:
continue
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
move_weight = True
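To make the budget arithmetic in _calc_flipflop_prefixes concrete, here is a worked example with hypothetical numbers (a 40-block model with 450 MiB blocks against a 12 GiB lowvram budget; the figures are for illustration only):

MiB = 1024 * 1024
block_size = 450 * MiB
total_blocks = 40
lowvram_model_memory = 12 * 1024 * MiB

n_fit_in_memory = int(lowvram_model_memory // block_size)  # 27 blocks fit fully
# +2 accounts for the VRAM the resident flip/flop copies themselves need
flipflop_blocks = min((total_blocks - n_fit_in_memory) + 2, total_blocks)  # 15
if flipflop_blocks % 2 != 0:  # even-count workaround described above
    if flipflop_blocks == total_blocks:
        flipflop_blocks -= 1
    else:
        flipflop_blocks += 1  # -> 16

flipflop_prefixes = [f"diffusion_model.blocks.{i}"
                     for i in range(total_blocks - flipflop_blocks, total_blocks)]
# -> diffusion_model.blocks.24 ... diffusion_model.blocks.39

# 24 blocks stay fully resident; with the 2 flip/flop copies that is
# 26 * 450 MiB ~= 11.4 GiB of VRAM, inside the 12 GiB budget.

Because block types are visited largest-first (reverse_sort_by_size=True), bigger block types claim the budget first and smaller ones end up streamed, matching the "smaller blocks getting flipflopped first" default from the commit history.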

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class FlipFlop(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="FlipFlopNew",
display_name="FlipFlop (New)",
category="_for_testing",
inputs=[
io.Model.Input(id="model"),
io.Float.Input(id="block_percentage", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output()
],
description="Apply FlipFlop transformation to model using setup_flipflop_holders method"
)
@classmethod
def execute(cls, model: io.Model.Type, block_percentage: float) -> io.NodeOutput:
# NOTE: this is still just a hacky prototype; it would not be exposed as a node.
# At the moment, this modifies the underlying model with no way to 'unpatch' it.
model = model.clone()
if not hasattr(model.model.diffusion_model, "setup_flipflop_holders"):
raise ValueError("Model does not have flipflop holders; FlipFlop not supported")
model.model.diffusion_model.setup_flipflop_holders(block_percentage)
return io.NodeOutput(model)
class FlipFlopExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
FlipFlop,
]
async def comfy_entrypoint() -> FlipFlopExtension:
return FlipFlopExtension()

View File

@@ -2329,6 +2329,7 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_flipflop.py",
]
import_failed = []