Compare commits
25 Commits
master
...
flipflop-s
| Author | SHA1 | Date |
|---|---|---|
|
|
386e854aab | |
|
|
61133af772 | |
|
|
586a8de8da | |
|
|
5329180fce | |
|
|
0fdd327c2f | |
|
|
ee01002e63 | |
|
|
831c3cf05e | |
|
|
0d8e8abd90 | |
|
|
d5001ed90e | |
|
|
8d7b22b720 | |
|
|
6d3ec9fcf3 | |
|
|
c4420b6a41 | |
|
|
a282586995 | |
|
|
0df61b5032 | |
|
|
7c896c5567 | |
|
|
ec156e72eb | |
|
|
01f4512bf8 | |
|
|
d0bd221495 | |
|
|
8a8162e8da | |
|
|
ff789c8beb | |
|
|
0e966dcf85 | |
|
|
6b240b0bce | |
|
|
f9fbf902d5 | |
|
|
f083720eb4 | |
|
|
84e73f2aa5 |
|
|
@ -132,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
|||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
|
||||
parser.add_argument("--flipflop-offload", action="store_true", help="Use async flipflop weight offloading for supported DiT models.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,200 @@
|
|||
from __future__ import annotations
|
||||
import torch
|
||||
import copy
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class FlipFlopModule(torch.nn.Module):
|
||||
def __init__(self, block_types: tuple[str, ...], enable_flipflop: bool = True):
|
||||
super().__init__()
|
||||
self.block_types = block_types
|
||||
self.enable_flipflop = enable_flipflop
|
||||
self.flipflop: dict[str, FlipFlopHolder] = {}
|
||||
self.block_info: dict[str, tuple[int, int]] = {}
|
||||
self.flipflop_prefixes: list[str] = []
|
||||
|
||||
def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], flipflop_prefixes: list[str], load_device: torch.device, offload_device: torch.device):
|
||||
for block_type, (flipflop_blocks, total_blocks) in block_info.items():
|
||||
if block_type in self.flipflop:
|
||||
continue
|
||||
self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[total_blocks-flipflop_blocks:], flipflop_blocks, total_blocks, load_device, offload_device)
|
||||
self.block_info[block_type] = (flipflop_blocks, total_blocks)
|
||||
self.flipflop_prefixes = flipflop_prefixes.copy()
|
||||
|
||||
def init_flipflop_block_copies(self, device: torch.device) -> int:
|
||||
memory_freed = 0
|
||||
for holder in self.flipflop.values():
|
||||
memory_freed += holder.init_flipflop_block_copies(device)
|
||||
return memory_freed
|
||||
|
||||
def clean_flipflop_holders(self):
|
||||
memory_freed = 0
|
||||
for block_type in list(self.flipflop.keys()):
|
||||
memory_freed += self.flipflop[block_type].clean_flipflop_blocks()
|
||||
del self.flipflop[block_type]
|
||||
self.block_info = {}
|
||||
self.flipflop_prefixes = []
|
||||
return memory_freed
|
||||
|
||||
def get_all_blocks(self, block_type: str) -> list[torch.nn.Module]:
|
||||
return getattr(self, block_type)
|
||||
|
||||
def get_blocks(self, block_type: str) -> torch.nn.ModuleList:
|
||||
if block_type not in self.block_types:
|
||||
raise ValueError(f"Block type {block_type} not found in {self.block_types}")
|
||||
if block_type in self.flipflop:
|
||||
return getattr(self, block_type)[:self.flipflop[block_type].i_offset]
|
||||
return getattr(self, block_type)
|
||||
|
||||
def get_all_block_module_sizes(self, reverse_sort_by_size: bool = False) -> list[tuple[str, int]]:
|
||||
'''
|
||||
Returns a list of (block_type, size) sorted by size.
|
||||
If reverse_sort_by_size is True, the list is sorted by size in reverse order.
|
||||
'''
|
||||
sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types]
|
||||
sizes.sort(key=lambda x: x[1], reverse=reverse_sort_by_size)
|
||||
return sizes
|
||||
|
||||
def get_block_module_size(self, block_type: str) -> int:
|
||||
return comfy.model_management.module_size(getattr(self, block_type)[0])
|
||||
|
||||
def execute_blocks(self, block_type: str, func, out: torch.Tensor | tuple[torch.Tensor,...], *args, **kwargs):
|
||||
# execute blocks, supporting both single and double (or higher) block types
|
||||
if isinstance(out, torch.Tensor):
|
||||
out = (out,)
|
||||
for i, block in enumerate(self.get_blocks(block_type)):
|
||||
out = func(i, block, *out, *args, **kwargs)
|
||||
if isinstance(out, torch.Tensor):
|
||||
out = (out,)
|
||||
if block_type in self.flipflop:
|
||||
holder = self.flipflop[block_type]
|
||||
with holder.context() as ctx:
|
||||
for i, block in enumerate(holder.blocks):
|
||||
out = ctx(func, i, block, *out, *args, **kwargs)
|
||||
if isinstance(out, torch.Tensor):
|
||||
out = (out,)
|
||||
if len(out) == 1:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
|
||||
class FlipFlopContext:
|
||||
def __init__(self, holder: FlipFlopHolder):
|
||||
# NOTE: there is a bug when there are an odd number of blocks to flipflop.
|
||||
# Worked around right now by always making sure it will be even, but need to resolve.
|
||||
self.holder = holder
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.num_blocks = len(self.holder.blocks)
|
||||
self.first_flip = True
|
||||
self.first_flop = True
|
||||
self.last_flip = False
|
||||
self.last_flop = False
|
||||
|
||||
def __enter__(self):
|
||||
self.reset()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.holder.compute_stream.record_event(self.holder.cpy_end_event)
|
||||
|
||||
def do_flip(self, func, i: int, _, *args, **kwargs):
|
||||
# flip
|
||||
self.holder.compute_stream.wait_event(self.holder.cpy_end_event)
|
||||
with torch.cuda.stream(self.holder.compute_stream):
|
||||
out = func(i+self.holder.i_offset, self.holder.flip, *args, **kwargs)
|
||||
self.holder.event_flip.record(self.holder.compute_stream)
|
||||
# while flip executes, queue flop to copy to its next block
|
||||
next_flop_i = i + 1
|
||||
if next_flop_i >= self.num_blocks:
|
||||
next_flop_i = next_flop_i - self.num_blocks
|
||||
self.last_flip = True
|
||||
if not self.first_flip:
|
||||
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
|
||||
if self.last_flip:
|
||||
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
|
||||
self.first_flip = False
|
||||
return out
|
||||
|
||||
def do_flop(self, func, i: int, _, *args, **kwargs):
|
||||
# flop
|
||||
if not self.first_flop:
|
||||
self.holder.compute_stream.wait_event(self.holder.cpy_end_event)
|
||||
with torch.cuda.stream(self.holder.compute_stream):
|
||||
out = func(i+self.holder.i_offset, self.holder.flop, *args, **kwargs)
|
||||
self.holder.event_flop.record(self.holder.compute_stream)
|
||||
# while flop executes, queue flip to copy to its next block
|
||||
next_flip_i = i + 1
|
||||
if next_flip_i >= self.num_blocks:
|
||||
next_flip_i = next_flip_i - self.num_blocks
|
||||
self.last_flop = True
|
||||
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
|
||||
if self.last_flop:
|
||||
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
|
||||
self.first_flop = False
|
||||
return out
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, func, i: int, block: torch.nn.Module, *args, **kwargs):
|
||||
# flips are even indexes, flops are odd indexes
|
||||
if i % 2 == 0:
|
||||
return self.do_flip(func, i, block, *args, **kwargs)
|
||||
else:
|
||||
return self.do_flop(func, i, block, *args, **kwargs)
|
||||
|
||||
|
||||
class FlipFlopHolder:
|
||||
def __init__(self, blocks: list[torch.nn.Module], flip_amount: int, total_amount: int, load_device: torch.device, offload_device: torch.device):
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
self.blocks = blocks
|
||||
self.flip_amount = flip_amount
|
||||
self.total_amount = total_amount
|
||||
# NOTE: used to make sure block indexes passed into block functions match expected patch indexes
|
||||
self.i_offset = total_amount - flip_amount
|
||||
|
||||
self.block_module_size = 0
|
||||
if len(self.blocks) > 0:
|
||||
self.block_module_size = comfy.model_management.module_size(self.blocks[0])
|
||||
|
||||
self.flip: torch.nn.Module = None
|
||||
self.flop: torch.nn.Module = None
|
||||
|
||||
self.compute_stream = torch.cuda.default_stream(self.load_device)
|
||||
self.cpy_stream = torch.cuda.Stream(self.load_device)
|
||||
|
||||
self.event_flip = torch.cuda.Event(enable_timing=False)
|
||||
self.event_flop = torch.cuda.Event(enable_timing=False)
|
||||
self.cpy_end_event = torch.cuda.Event(enable_timing=False)
|
||||
# INIT - is this actually needed?
|
||||
self.compute_stream.record_event(self.cpy_end_event)
|
||||
|
||||
def _copy_state_dict(self, dst, src, cpy_start_event: torch.cuda.Event=None, cpy_end_event: torch.cuda.Event=None):
|
||||
if cpy_start_event:
|
||||
self.cpy_stream.wait_event(cpy_start_event)
|
||||
|
||||
with torch.cuda.stream(self.cpy_stream):
|
||||
for k, v in src.items():
|
||||
dst[k].copy_(v, non_blocking=True)
|
||||
if cpy_end_event:
|
||||
cpy_end_event.record(self.cpy_stream)
|
||||
|
||||
def context(self):
|
||||
return FlipFlopContext(self)
|
||||
|
||||
def init_flipflop_block_copies(self, load_device: torch.device) -> int:
|
||||
self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
|
||||
self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
|
||||
return comfy.model_management.module_size(self.flip) + comfy.model_management.module_size(self.flop)
|
||||
|
||||
def clean_flipflop_blocks(self) -> int:
|
||||
memory_freed = 0
|
||||
memory_freed += comfy.model_management.module_size(self.flip)
|
||||
memory_freed += comfy.model_management.module_size(self.flop)
|
||||
del self.flip
|
||||
del self.flop
|
||||
self.flip = None
|
||||
self.flop = None
|
||||
return memory_freed
|
||||
|
|
@ -7,6 +7,7 @@ from torch import Tensor, nn
|
|||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flipflop_transformer import FlipFlopModule
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
|
|
@ -35,13 +36,13 @@ class FluxParams:
|
|||
guidance_embed: bool
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
class Flux(FlipFlopModule):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
super().__init__(("double_blocks", "single_blocks"))
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
|
|
@ -89,6 +90,72 @@ class Flux(nn.Module):
|
|||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def indiv_double_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img[:, :add.shape[1]] += add
|
||||
return img, txt
|
||||
|
||||
def indiv_single_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
|
||||
return img
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
|
|
@ -136,74 +203,16 @@ class Flux(nn.Module):
|
|||
pe = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img[:, :add.shape[1]] += add
|
||||
# execute double blocks
|
||||
img, txt = self.execute_blocks("double_blocks", self.indiv_double_block_fwd, (img, txt), vec, pe, attn_mask, control, blocks_replace, transformer_options)
|
||||
|
||||
if img.dtype == torch.float16:
|
||||
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
|
||||
# execute single blocks
|
||||
img = self.execute_blocks("single_blocks", self.indiv_single_block_fwd, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options)
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
|
|
|
|||
|
|
@ -5,11 +5,13 @@ import torch.nn.functional as F
|
|||
from typing import Optional, Tuple
|
||||
from einops import repeat
|
||||
|
||||
from comfy.ldm.flipflop_transformer import FlipFlopModule
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
import comfy.ops
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||
|
|
@ -283,7 +285,7 @@ class LastLayer(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class QwenImageTransformer2DModel(nn.Module):
|
||||
class QwenImageTransformer2DModel(FlipFlopModule):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
|
|
@ -300,9 +302,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
final_layer=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
operations: comfy.ops.disable_weight_init=None,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(block_types=("transformer_blocks",))
|
||||
self.dtype = dtype
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
|
|
@ -366,6 +368,40 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def indiv_block_fwd(self, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
|
|
@ -433,37 +469,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
out = (hidden_states, encoder_hidden_states)
|
||||
hidden_states, encoder_hidden_states = self.execute_blocks("transformer_blocks", self.indiv_block_fwd, out, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import torch.nn as nn
|
|||
from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flipflop_transformer import FlipFlopModule
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
import comfy.ldm.common_dit
|
||||
|
|
@ -384,7 +385,7 @@ class MLPProj(torch.nn.Module):
|
|||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class WanModel(torch.nn.Module):
|
||||
class WanModel(FlipFlopModule):
|
||||
r"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
|
@ -412,6 +413,7 @@ class WanModel(torch.nn.Module):
|
|||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
enable_flipflop=True,
|
||||
):
|
||||
r"""
|
||||
Initialize the diffusion model backbone.
|
||||
|
|
@ -449,7 +451,7 @@ class WanModel(torch.nn.Module):
|
|||
Epsilon value for normalization layers
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop)
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
|
|
@ -506,6 +508,18 @@ class WanModel(torch.nn.Module):
|
|||
else:
|
||||
self.ref_conv = None
|
||||
|
||||
def indiv_block_fwd(self, i, block, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
|
|
@ -567,16 +581,8 @@ class WanModel(torch.nn.Module):
|
|||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
# execute blocks
|
||||
x = self.execute_blocks("blocks", self.indiv_block_fwd, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
|
@ -688,7 +694,7 @@ class VaceWanModel(WanModel):
|
|||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
# Vace
|
||||
|
|
@ -808,7 +814,7 @@ class CameraWanModel(WanModel):
|
|||
else:
|
||||
model_type = 't2v'
|
||||
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
|
||||
|
|
@ -1211,7 +1217,7 @@ class WanModel_S2V(WanModel):
|
|||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
|
||||
|
||||
self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
|
||||
|
||||
|
|
@ -1511,7 +1517,7 @@ class HumoWanModel(WanModel):
|
|||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
|
||||
|
||||
self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
|
|
|
|||
|
|
@ -426,7 +426,7 @@ class AnimateWanModel(WanModel):
|
|||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
|
||||
|
||||
self.pose_patch_embedding = operations.Conv3d(
|
||||
16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
|
||||
|
|
|
|||
|
|
@ -1006,6 +1006,8 @@ def force_channels_last():
|
|||
#TODO
|
||||
return False
|
||||
|
||||
def flipflop_enabled():
|
||||
return args.flipflop_offload
|
||||
|
||||
STREAMS = {}
|
||||
NUM_STREAMS = 1
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import logging
|
|||
import math
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
import time # TODO remove
|
||||
import torch
|
||||
|
||||
import comfy.float
|
||||
|
|
@ -591,7 +591,7 @@ class ModelPatcher:
|
|||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, device_final=None):
|
||||
if key not in self.patches:
|
||||
return
|
||||
|
||||
|
|
@ -611,15 +611,103 @@ class ModelPatcher:
|
|||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||
if set_func is None:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
if device_final is not None:
|
||||
out_weight = out_weight.to(device_final)
|
||||
if inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||
else:
|
||||
if device_final is not None:
|
||||
out_weight = out_weight.to(device_final)
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
|
||||
def _load_list(self):
|
||||
def supports_flipflop(self):
|
||||
# flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
|
||||
if not comfy.model_management.flipflop_enabled():
|
||||
return False
|
||||
if not hasattr(self.model, "diffusion_model"):
|
||||
return False
|
||||
if not getattr(self.model.diffusion_model, "enable_flipflop", False):
|
||||
return False
|
||||
if not comfy.model_management.is_nvidia():
|
||||
return False
|
||||
if comfy.model_management.vram_state in (comfy.model_management.VRAMState.HIGH_VRAM, comfy.model_management.VRAMState.SHARED):
|
||||
return False
|
||||
return True
|
||||
|
||||
def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]], flipflop_prefixes: list[str]):
|
||||
if not self.supports_flipflop():
|
||||
return
|
||||
logging.info(f"setting up flipflop with {flipflop_blocks_per_type}")
|
||||
self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, flipflop_prefixes, self.load_device, self.offload_device)
|
||||
|
||||
def init_flipflop_block_copies(self) -> int:
|
||||
if not self.supports_flipflop():
|
||||
return 0
|
||||
return self.model.diffusion_model.init_flipflop_block_copies(self.load_device)
|
||||
|
||||
def clean_flipflop(self) -> int:
|
||||
if not self.supports_flipflop():
|
||||
return 0
|
||||
return self.model.diffusion_model.clean_flipflop_holders()
|
||||
|
||||
def _get_existing_flipflop_prefixes(self):
|
||||
if self.supports_flipflop():
|
||||
return self.model.diffusion_model.flipflop_prefixes
|
||||
return []
|
||||
|
||||
def _calc_flipflop_prefixes(self, lowvram_model_memory=0, prepare_flipflop=False):
|
||||
flipflop_prefixes = []
|
||||
flipflop_blocks_per_type: dict[str, tuple[int, int]] = {}
|
||||
if lowvram_model_memory > 0 and self.supports_flipflop():
|
||||
block_buffer = 3
|
||||
valid_block_types = []
|
||||
# for each block type, check if have enough room to flipflop
|
||||
for block_info in self.model.diffusion_model.get_all_block_module_sizes(reverse_sort_by_size=True):
|
||||
block_size: int = block_info[1]
|
||||
if block_size * block_buffer < lowvram_model_memory:
|
||||
valid_block_types.append(block_info)
|
||||
# if have candidates for flipping, see how many of each type we have can flipflop
|
||||
if len(valid_block_types) > 0:
|
||||
leftover_memory = lowvram_model_memory
|
||||
for block_info in valid_block_types:
|
||||
block_type: str = block_info[0]
|
||||
block_size: int = block_info[1]
|
||||
total_blocks = len(self.model.diffusion_model.get_all_blocks(block_type))
|
||||
n_fit_in_memory = int(leftover_memory // block_size)
|
||||
# if all (or more) of this block type would fit in memory, no need to flipflop with it
|
||||
if n_fit_in_memory >= total_blocks:
|
||||
leftover_memory -= total_blocks * block_size
|
||||
continue
|
||||
# if the amount of this block that would fit in memory is less than buffer, skip this block type
|
||||
if n_fit_in_memory < block_buffer:
|
||||
continue
|
||||
# 2 blocks worth of VRAM may be needed for flipflop, so make sure to account for them.
|
||||
flipflop_blocks = min((total_blocks - n_fit_in_memory) + 2, total_blocks)
|
||||
# for now, work around odd number issue by making it even
|
||||
if flipflop_blocks % 2 != 0:
|
||||
if flipflop_blocks == total_blocks:
|
||||
flipflop_blocks -= 1
|
||||
else:
|
||||
flipflop_blocks += 1
|
||||
flipflop_blocks_per_type[block_type] = (flipflop_blocks, total_blocks)
|
||||
leftover_memory -= (total_blocks - flipflop_blocks + 2) * block_size
|
||||
# if there are blocks to flipflop, need to mark their keys
|
||||
for block_type, (flipflop_blocks, total_blocks) in flipflop_blocks_per_type.items():
|
||||
# blocks to flipflop are at the end
|
||||
for i in range(total_blocks-flipflop_blocks, total_blocks):
|
||||
flipflop_prefixes.append(f"diffusion_model.{block_type}.{i}")
|
||||
if prepare_flipflop and len(flipflop_blocks_per_type) > 0:
|
||||
self.setup_flipflop(flipflop_blocks_per_type, flipflop_prefixes)
|
||||
return flipflop_prefixes
|
||||
|
||||
def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False, get_existing_flipflop=False):
|
||||
loading = []
|
||||
if get_existing_flipflop:
|
||||
flipflop_prefixes = self._get_existing_flipflop_prefixes()
|
||||
else:
|
||||
flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop)
|
||||
for n, m in self.model.named_modules():
|
||||
params = []
|
||||
skip = False
|
||||
|
|
@ -630,7 +718,12 @@ class ModelPatcher:
|
|||
skip = True # skip random weights in non leaf modules
|
||||
break
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||
flipflop = False
|
||||
for prefix in flipflop_prefixes:
|
||||
if n.startswith(prefix):
|
||||
flipflop = True
|
||||
break
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params, flipflop))
|
||||
return loading
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
|
|
@ -639,14 +732,19 @@ class ModelPatcher:
|
|||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
loading = self._load_list()
|
||||
lowvram_mem_counter = 0
|
||||
flipflop_counter = 0
|
||||
flipflop_mem_counter = 0
|
||||
loading = self._load_list(lowvram_model_memory, prepare_flipflop=True)
|
||||
|
||||
load_completely = []
|
||||
load_flipflop = []
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
params = x[3]
|
||||
flipflop: bool = x[4]
|
||||
module_mem = x[0]
|
||||
|
||||
lowvram_weight = False
|
||||
|
|
@ -654,10 +752,11 @@ class ModelPatcher:
|
|||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
if not full_load and hasattr(m, "comfy_cast_weights") and not flipflop:
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
lowvram_mem_counter += module_mem
|
||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||
continue
|
||||
|
||||
|
|
@ -687,7 +786,11 @@ class ModelPatcher:
|
|||
if hasattr(m, "comfy_cast_weights"):
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if full_load or mem_counter + module_mem < lowvram_model_memory:
|
||||
if flipflop:
|
||||
flipflop_counter += 1
|
||||
flipflop_mem_counter += module_mem
|
||||
load_flipflop.append((module_mem, n, m, params))
|
||||
elif full_load or mem_counter + module_mem < lowvram_model_memory:
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m, params))
|
||||
|
||||
|
|
@ -703,6 +806,7 @@ class ModelPatcher:
|
|||
|
||||
mem_counter += move_weight_functions(m, device_to)
|
||||
|
||||
# handle load completely
|
||||
load_completely.sort(reverse=True)
|
||||
for x in load_completely:
|
||||
n = x[1]
|
||||
|
|
@ -721,11 +825,36 @@ class ModelPatcher:
|
|||
for x in load_completely:
|
||||
x[2].to(device_to)
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
# handle flipflop
|
||||
if len(load_flipflop) > 0:
|
||||
start_time = time.perf_counter()
|
||||
load_flipflop.sort(reverse=True)
|
||||
for x in load_flipflop:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
params = x[3]
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
if m.comfy_patched_weights == True:
|
||||
continue
|
||||
for param in params:
|
||||
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to, device_final=self.offload_device)
|
||||
|
||||
logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m))
|
||||
end_time = time.perf_counter()
|
||||
logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds")
|
||||
start_time = time.perf_counter()
|
||||
mem_counter += self.init_flipflop_block_copies()
|
||||
end_time = time.perf_counter()
|
||||
logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds")
|
||||
|
||||
if lowvram_counter > 0 or flipflop_counter > 0:
|
||||
if flipflop_counter > 0:
|
||||
logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {flipflop_mem_counter / (1024 * 1024):.2f} MB flipflop, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
|
||||
else:
|
||||
logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}")
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
|
|
@ -762,6 +891,7 @@ class ModelPatcher:
|
|||
self.eject_model()
|
||||
if unpatch_weights:
|
||||
self.unpatch_hooks()
|
||||
self.clean_flipflop()
|
||||
if self.model.model_lowvram:
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
|
|
@ -801,8 +931,9 @@ class ModelPatcher:
|
|||
with self.use_ejected():
|
||||
hooks_unpatched = False
|
||||
memory_freed = 0
|
||||
memory_freed += self.clean_flipflop()
|
||||
patch_counter = 0
|
||||
unload_list = self._load_list()
|
||||
unload_list = self._load_list(get_existing_flipflop=True)
|
||||
unload_list.sort()
|
||||
for unload in unload_list:
|
||||
if memory_to_free < memory_freed:
|
||||
|
|
@ -811,7 +942,10 @@ class ModelPatcher:
|
|||
n = unload[1]
|
||||
m = unload[2]
|
||||
params = unload[3]
|
||||
flipflop: bool = unload[4]
|
||||
|
||||
if flipflop:
|
||||
continue
|
||||
lowvram_possible = hasattr(m, "comfy_cast_weights")
|
||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||
move_weight = True
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
from __future__ import annotations
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class FlipFlop(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="FlipFlopNew",
|
||||
display_name="FlipFlop (New)",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
io.Model.Input(id="model"),
|
||||
io.Float.Input(id="block_percentage", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output()
|
||||
],
|
||||
description="Apply FlipFlop transformation to model using setup_flipflop_holders method"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, block_percentage: float) -> io.NodeOutput:
|
||||
# NOTE: this is just a hacky prototype still, this would not be exposed as a node.
|
||||
# At the moment, this modifies the underlying model with no way to 'unpatch' it.
|
||||
model = model.clone()
|
||||
if not hasattr(model.model.diffusion_model, "setup_flipflop_holders"):
|
||||
raise ValueError("Model does not have flipflop holders; FlipFlop not supported")
|
||||
model.model.diffusion_model.setup_flipflop_holders(block_percentage)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
class FlipFlopExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
FlipFlop,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> FlipFlopExtension:
|
||||
return FlipFlopExtension()
|
||||
Loading…
Reference in New Issue