Compare commits
2 Commits
master
...
release/v0
| Author | SHA1 | Date |
|---|---|---|
|
|
59afc39848 | |
|
|
028e17dd7a |
|
|
@ -53,16 +53,6 @@ try:
|
|||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash") # noqa: T201
|
||||
except:
|
||||
print("Could not stash, cleaning index and trying again.") # noqa: T201
|
||||
repo.state_cleanup()
|
||||
repo.index.read_tree(repo.head.peel().tree)
|
||||
repo.index.write()
|
||||
try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash.") # noqa: T201
|
||||
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
||||
try:
|
||||
|
|
@ -76,10 +66,8 @@ if branch is None:
|
|||
try:
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
except:
|
||||
print("fetching.") # noqa: T201
|
||||
for remote in repo.remotes:
|
||||
if remote.name == "origin":
|
||||
remote.fetch()
|
||||
print("pulling.") # noqa: T201
|
||||
pull(repo)
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
repo.checkout(ref)
|
||||
branch = repo.lookup_branch('master')
|
||||
|
|
@ -161,4 +149,3 @@ try:
|
|||
shutil.copy(stable_update_script, stable_update_script_to)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ on:
|
|||
push:
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
paths-ignore:
|
||||
- 'app/**'
|
||||
- 'input/**'
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ name: Execution Tests
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ name: Test server launches without errors
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ name: Unit Tests
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master, release/** ]
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ on:
|
|||
- "pyproject.toml"
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
|
||||
jobs:
|
||||
update-version:
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
# Admins
|
||||
* @comfyanonymous @kosinkadink @guill
|
||||
* @comfyanonymous
|
||||
* @kosinkadink
|
||||
|
|
|
|||
32
README.md
32
README.md
|
|
@ -81,7 +81,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
|||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
||||
- [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
|
|
@ -119,9 +118,6 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
|
|||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
||||
- Minor versions will be used for releases off the master branch.
|
||||
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
||||
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||
- Serves as the foundation for the desktop release
|
||||
|
||||
|
|
@ -212,8 +208,6 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
|
|||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
|
||||
Git clone this repo.
|
||||
|
|
@ -325,32 +319,6 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
|||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
|
||||
|
||||
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
|
||||
|
||||
### Setup
|
||||
|
||||
1. Install the manager dependencies:
|
||||
```bash
|
||||
pip install -r manager_requirements.txt
|
||||
```
|
||||
|
||||
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
|
||||
```bash
|
||||
python main.py --enable-manager
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
|
||||
|
||||
# Running
|
||||
|
||||
```python main.py```
|
||||
|
|
|
|||
|
|
@ -58,13 +58,8 @@ class InternalRoutes:
|
|||
return web.json_response({"error": "Invalid directory type"}, status=400)
|
||||
|
||||
directory = get_directory_by_type(directory_type)
|
||||
|
||||
def is_visible_file(entry: os.DirEntry) -> bool:
|
||||
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
|
||||
return entry.is_file() and not entry.name.startswith('.')
|
||||
|
||||
sorted_files = sorted(
|
||||
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
|
||||
(entry for entry in os.scandir(directory) if entry.is_file()),
|
||||
key=lambda entry: -entry.stat().st_mtime
|
||||
)
|
||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class ModelFileManager:
|
|||
@routes.get("/experiment/models/{folder}")
|
||||
async def get_all_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if folder not in folder_paths.folder_names_and_paths:
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = self.get_model_file_list(folder)
|
||||
return web.json_response(files)
|
||||
|
|
@ -55,7 +55,7 @@ class ModelFileManager:
|
|||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if folder_name not in folder_paths.folder_names_and_paths:
|
||||
if not folder_name in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
|
|
|
|||
|
|
@ -97,13 +97,6 @@ class LatentPreviewMethod(enum.Enum):
|
|||
Latent2RGB = "latent2rgb"
|
||||
TAESD = "taesd"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str):
|
||||
for member in cls:
|
||||
if member.value == value:
|
||||
return member
|
||||
return None
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
|
@ -128,12 +121,6 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
|
|||
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
||||
|
||||
|
||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||
manager_group = parser.add_mutually_exclusive_group()
|
||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||
|
|
@ -181,7 +168,6 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
|
|||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,25 +2,6 @@ import torch
|
|||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
|
@ -16,7 +17,24 @@ class Output:
|
|||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
|
|
@ -55,7 +73,7 @@ class ClipVisionModel():
|
|||
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
|
|
|
|||
|
|
@ -51,43 +51,32 @@ class ContextHandlerABC(ABC):
|
|||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||
def __init__(self, index_list: list[int], dim: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
window = full[idx]
|
||||
if retain_index_list:
|
||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||
window[idx] = full[idx]
|
||||
return window.to(device)
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
return full[idx].to(device)
|
||||
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
full[idx] += to_add
|
||||
return full
|
||||
|
||||
def get_region_index(self, num_regions: int) -> int:
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
||||
EXECUTE_START = "execute_start"
|
||||
EXECUTE_CLEANUP = "execute_cleanup"
|
||||
RESIZE_COND_ITEM = "resize_cond_item"
|
||||
|
||||
def init_callbacks(self):
|
||||
return {}
|
||||
|
|
@ -105,8 +94,7 @@ class ContextFuseMethod:
|
|||
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
|
|
@ -115,18 +103,13 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||
self.closed_loop = closed_loop
|
||||
self.dim = dim
|
||||
self._step = 0
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
|
@ -140,11 +123,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||
return None
|
||||
# reuse or resize cond items to match context requirements
|
||||
resized_cond = []
|
||||
# if multiple conds, split based on primary region
|
||||
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||
region = window.get_region_index(len(cond_in))
|
||||
logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
|
||||
cond_in = [cond_in[region]]
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
resized_actual_cond = actual_cond.copy()
|
||||
|
|
@ -167,38 +145,13 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||
new_cond_item = cond_item.copy()
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
# Allow callbacks to handle custom conditioning items
|
||||
handled = False
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(
|
||||
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
|
||||
):
|
||||
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
audio_cond = cond_value.cond
|
||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||
# Handle vace_context (temporal dim is 3)
|
||||
elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
vace_cond = cond_value.cond
|
||||
if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
|
||||
sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
|
||||
new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
|
||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||
new_cond_item[cond_key].cond = window.context_length
|
||||
|
|
@ -211,7 +164,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
|
|
@ -220,7 +173,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
|
|
@ -297,8 +250,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||
prev_weight = (bias_total / (bias_total + bias))
|
||||
new_weight = (bias / (bias_total + bias))
|
||||
# account for dims of tensors
|
||||
idx_window = tuple([slice(None)] * self.dim + [idx])
|
||||
pos_window = tuple([slice(None)] * self.dim + [pos])
|
||||
idx_window = [slice(None)] * self.dim + [idx]
|
||||
pos_window = [slice(None)] * self.dim + [pos]
|
||||
# apply new values
|
||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
|
|
@ -334,28 +287,6 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
|
|||
)
|
||||
|
||||
|
||||
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
|
||||
model_options = extra_args.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is None:
|
||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
"ContextWindows_sampler_sample",
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
|
|
@ -607,29 +538,3 @@ def shift_window_to_end(window: list[int], num_frames: int):
|
|||
for i in range(len(window)):
|
||||
# 2) add end_delta to each val to slide windows to end
|
||||
window[i] = window[i] + end_delta
|
||||
|
||||
|
||||
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
||||
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
|
||||
logging.info("Context windows: Applying FreeNoise")
|
||||
generator = torch.Generator(device='cpu').manual_seed(seed)
|
||||
latent_video_length = noise.shape[dim]
|
||||
delta = context_length - context_overlap
|
||||
|
||||
for start_idx in range(0, latent_video_length - context_length, delta):
|
||||
place_idx = start_idx + context_length
|
||||
|
||||
actual_delta = min(delta, latent_video_length - place_idx)
|
||||
if actual_delta <= 0:
|
||||
break
|
||||
|
||||
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
|
||||
|
||||
source_slice = [slice(None)] * noise.ndim
|
||||
source_slice[dim] = list_idx
|
||||
target_slice = [slice(None)] * noise.ndim
|
||||
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
|
||||
|
||||
noise[tuple(target_slice)] = noise[tuple(source_slice)]
|
||||
|
||||
return noise
|
||||
|
|
|
|||
|
|
@ -527,8 +527,7 @@ class HookKeyframeGroup:
|
|||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else:
|
||||
break
|
||||
else: break
|
||||
# update steps current context is used
|
||||
self._current_used_steps += 1
|
||||
# update current timestep this was performed on
|
||||
|
|
|
|||
|
|
@ -74,9 +74,6 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
|||
|
||||
def default_noise_sampler(x, seed=None):
|
||||
if seed is not None:
|
||||
if x.device == torch.device("cpu"):
|
||||
seed += 1
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
else:
|
||||
|
|
@ -1560,13 +1557,10 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
if solver_type not in {"phi_1", "phi_2"}:
|
||||
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
|
|
@ -1606,14 +1600,8 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
|||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
if solver_type == "phi_1":
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
elif solver_type == "phi_2":
|
||||
b2 = ei_h_phi_2(-h_eta) / r
|
||||
b1 = ei_h_phi_1(-h_eta) - b2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
|
|
@ -1621,17 +1609,6 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
|||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
|
||||
"""Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
|
||||
"""Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
|
|
@ -1779,7 +1756,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
|||
# Predictor
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x_pred = denoised
|
||||
x = denoised
|
||||
else:
|
||||
tau_t = tau_func(sigmas[i + 1])
|
||||
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
||||
|
|
@ -1800,7 +1777,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
|||
if tau_t > 0 and s_noise > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
||||
x_pred = x_pred + noise
|
||||
return x_pred
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
|||
|
|
@ -40,8 +40,7 @@ class ChromaParams:
|
|||
out_dim: int
|
||||
hidden_dim: int
|
||||
n_layers: int
|
||||
txt_ids_dims: list
|
||||
vec_in_dim: int
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
|
|||
nerf_final_head_type: str
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
use_x0: bool
|
||||
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
|
|
@ -159,9 +159,6 @@ class ChromaRadiance(Chroma):
|
|||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
if params.use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
|
|
@ -270,7 +267,7 @@ class ChromaRadiance(Chroma):
|
|||
bad_keys = tuple(
|
||||
k
|
||||
for k, v in overrides.items()
|
||||
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
|
||||
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
||||
)
|
||||
if bad_keys:
|
||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
|
|
@ -279,12 +276,6 @@ class ChromaRadiance(Chroma):
|
|||
params_dict |= overrides
|
||||
return params.__class__(**params_dict)
|
||||
|
||||
def _apply_x0_residual(self, predicted, noisy, timesteps):
|
||||
|
||||
# non zero during training to prevent 0 div
|
||||
eps = 0.0
|
||||
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
|
|
@ -325,11 +316,4 @@ class ChromaRadiance(Chroma):
|
|||
transformer_options,
|
||||
attn_mask=kwargs.get("attention_mask", None),
|
||||
)
|
||||
|
||||
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
|
||||
# If x0 variant → v-pred, just return this instead
|
||||
if hasattr(self, "__x0__"):
|
||||
out = self._apply_x0_residual(out, img, timestep)
|
||||
return out
|
||||
|
||||
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
|
|
|
|||
|
|
@ -57,35 +57,6 @@ class MLPEmbedder(nn.Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
class YakMLP(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
if yak_mlp:
|
||||
return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
|
||||
if mlp_silu_act:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
|
|
@ -169,7 +140,7 @@ class SiLUActivation(nn.Module):
|
|||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
|
@ -185,7 +156,18 @@ class DoubleStreamBlock(nn.Module):
|
|||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
if mlp_silu_act:
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
if self.modulation:
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
|
@ -195,7 +177,18 @@ class DoubleStreamBlock(nn.Module):
|
|||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
if mlp_silu_act:
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
|
|
@ -282,7 +275,6 @@ class SingleStreamBlock(nn.Module):
|
|||
modulation=True,
|
||||
mlp_silu_act=False,
|
||||
bias=True,
|
||||
yak_mlp=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
|
|
@ -296,17 +288,12 @@ class SingleStreamBlock(nn.Module):
|
|||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.mlp_hidden_dim_first = self.mlp_hidden_dim
|
||||
self.yak_mlp = yak_mlp
|
||||
if mlp_silu_act:
|
||||
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
|
||||
self.mlp_act = SiLUActivation()
|
||||
else:
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
if self.yak_mlp:
|
||||
self.mlp_hidden_dim_first *= 2
|
||||
self.mlp_act = nn.SiLU()
|
||||
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
|
|
@ -338,10 +325,7 @@ class SingleStreamBlock(nn.Module):
|
|||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
if self.yak_mlp:
|
||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||
else:
|
||||
mlp = self.mlp_act(mlp)
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = self.linear2(torch.cat((attn, mlp), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from .layers import (
|
|||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation,
|
||||
RMSNorm
|
||||
Modulation
|
||||
)
|
||||
|
||||
@dataclass
|
||||
|
|
@ -35,14 +34,11 @@ class FluxParams:
|
|||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
txt_ids_dims: list
|
||||
global_modulation: bool = False
|
||||
mlp_silu_act: bool = False
|
||||
ops_bias: bool = True
|
||||
default_ref_method: str = "offset"
|
||||
ref_index_scale: float = 1.0
|
||||
yak_mlp: bool = False
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
|
|
@ -80,11 +76,6 @@ class Flux(nn.Module):
|
|||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
|
|
@ -95,7 +86,6 @@ class Flux(nn.Module):
|
|||
modulation=params.global_modulation is False,
|
||||
mlp_silu_act=params.mlp_silu_act,
|
||||
proj_bias=params.ops_bias,
|
||||
yak_mlp=params.yak_mlp,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
|
|
@ -104,7 +94,7 @@ class Flux(nn.Module):
|
|||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
|
@ -160,8 +150,6 @@ class Flux(nn.Module):
|
|||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
|
||||
if self.txt_norm is not None:
|
||||
txt = self.txt_norm(txt)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
|
|
@ -344,9 +332,8 @@ class Flux(nn.Module):
|
|||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
if len(self.params.txt_ids_dims) > 0:
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
if len(self.params.axes_dim) == 4: # Flux 2
|
||||
txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ class HunyuanVideoParams:
|
|||
meanflow: bool
|
||||
use_cond_type_embedding: bool
|
||||
vision_in_dim: int
|
||||
meanflow_sum: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
|
|
@ -318,7 +317,7 @@ class HunyuanVideo(nn.Module):
|
|||
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
||||
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
|
||||
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
|
||||
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
|
||||
vec = (vec + vec_r) / 2
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||
import model_management
|
||||
import model_patcher
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
|
||||
import model_management, model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
|
|
|
|||
|
|
@ -1,12 +1,42 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
|
||||
import comfy.ops
|
||||
import comfy.ldm.models.autoencoder
|
||||
import comfy.model_management
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class NoPadConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
|
||||
|
||||
x = xl[0]
|
||||
xl.clear()
|
||||
|
||||
if conv_carry_out is not None:
|
||||
to_push = x[:, :, -2:, :, :].clone()
|
||||
conv_carry_out.append(to_push)
|
||||
|
||||
if isinstance(op, NoPadConv3d):
|
||||
if conv_carry_in is None:
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
|
||||
else:
|
||||
carry_len = conv_carry_in[0].shape[2]
|
||||
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
|
||||
|
||||
out = op(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
|
|
@ -19,7 +49,7 @@ class RMS_norm(nn.Module):
|
|||
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
|
||||
|
||||
class DnSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tds, refiner_vae, op):
|
||||
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
|
||||
assert oc % fct == 0
|
||||
|
|
@ -79,7 +109,7 @@ class DnSmpl(nn.Module):
|
|||
|
||||
|
||||
class UpSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tus, refiner_vae, op):
|
||||
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
|
||||
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
|
||||
|
|
@ -133,6 +163,23 @@ class UpSmpl(nn.Module):
|
|||
|
||||
return h + x
|
||||
|
||||
class HunyuanRefinerResnetBlock(ResnetBlock):
|
||||
def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
|
||||
super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
h = x
|
||||
h = [ self.swish(self.norm1(x)) ]
|
||||
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
h = [ self.dropout(self.swish(self.norm2(h))) ]
|
||||
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
|
||||
|
|
@ -144,7 +191,7 @@ class Encoder(nn.Module):
|
|||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = CarriedConv3d
|
||||
conv_op = NoPadConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
|
|
@ -159,10 +206,9 @@ class Encoder(nn.Module):
|
|||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
|
|
@ -172,9 +218,9 @@ class Encoder(nn.Module):
|
|||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
|
||||
|
|
@ -200,20 +246,22 @@ class Encoder(nn.Module):
|
|||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
|
||||
x1 = [ x1 ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
x1 = blk(x1, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
if len(out) > 1:
|
||||
out = torch.cat(out, dim=2)
|
||||
else:
|
||||
out = out[0]
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
|
||||
del out
|
||||
|
|
@ -240,7 +288,7 @@ class Decoder(nn.Module):
|
|||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = CarriedConv3d
|
||||
conv_op = NoPadConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
|
|
@ -250,9 +298,9 @@ class Decoder(nn.Module):
|
|||
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
|
@ -260,10 +308,9 @@ class Decoder(nn.Module):
|
|||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
|
|
@ -293,7 +340,7 @@ class Decoder(nn.Module):
|
|||
conv_carry_out = None
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
x1 = blk(x1, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
|
|
@ -303,7 +350,10 @@ class Decoder(nn.Module):
|
|||
conv_carry_in = conv_carry_out
|
||||
del x
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
if len(out) > 1:
|
||||
out = torch.cat(out, dim=2)
|
||||
else:
|
||||
out = out[0]
|
||||
|
||||
if not self.refiner_vae:
|
||||
if z.shape[-3] == 1:
|
||||
|
|
|
|||
|
|
@ -1,413 +0,0 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
def attention(q, k, v, heads, transformer_options={}):
|
||||
return optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
heads=heads,
|
||||
skip_reshape=True,
|
||||
transformer_options=transformer_options
|
||||
)
|
||||
|
||||
def apply_scale_shift_norm(norm, x, scale, shift):
|
||||
return torch.addcmul(shift, norm(x), scale + 1.0)
|
||||
|
||||
def apply_gate_sum(x, out, gate):
|
||||
return torch.addcmul(x, gate, out)
|
||||
|
||||
def get_shift_scale_gate(params):
|
||||
shift, scale, gate = torch.chunk(params, 3, dim=-1)
|
||||
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
|
||||
|
||||
def get_freqs(dim, max_period=10000.0):
|
||||
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
|
||||
|
||||
|
||||
class TimeEmbeddings(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
|
||||
super().__init__()
|
||||
assert model_dim % 2 == 0
|
||||
self.model_dim = model_dim
|
||||
self.max_period = max_period
|
||||
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
|
||||
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
|
||||
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||
return time_embed
|
||||
|
||||
|
||||
class TextEmbeddings(nn.Module):
|
||||
def __init__(self, text_dim, model_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, text_embed):
|
||||
text_embed = self.in_layer(text_embed)
|
||||
return self.norm(text_embed).type_as(text_embed)
|
||||
|
||||
|
||||
class VisualEmbeddings(nn.Module):
|
||||
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.movedim(1, -1) # B C T H W -> B T H W C
|
||||
B, T, H, W, dim = x.shape
|
||||
pt, ph, pw = self.patch_size
|
||||
|
||||
x = x.view(
|
||||
B,
|
||||
T // pt, pt,
|
||||
H // ph, ph,
|
||||
W // pw, pw,
|
||||
dim,
|
||||
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
|
||||
|
||||
return self.in_layer(x)
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
|
||||
super().__init__()
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(x))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, num_channels, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
assert num_channels % head_dim == 0
|
||||
self.num_heads = num_channels // head_dim
|
||||
self.head_dim = head_dim
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 2
|
||||
|
||||
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
|
||||
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
return apply_rope1(norm_fn(result), freqs)
|
||||
|
||||
def _forward(self, x, freqs, transformer_options={}):
|
||||
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
||||
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def _forward_chunked(self, x, freqs, transformer_options={}):
|
||||
def process_chunks(proj_fn, norm_fn):
|
||||
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
||||
chunks = []
|
||||
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
|
||||
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
q = process_chunks(self.to_query, self.query_norm)
|
||||
k = process_chunks(self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def forward(self, x, freqs, transformer_options={}):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
|
||||
else:
|
||||
return self._forward(x, freqs, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class CrossAttention(SelfAttention):
|
||||
def get_qkv(self, x, context):
|
||||
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
return q, k, v
|
||||
|
||||
def forward(self, x, context, transformer_options={}):
|
||||
q, k, v = self.get_qkv(x, context)
|
||||
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, ff_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.GELU()
|
||||
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 4
|
||||
|
||||
def _forward(self, x):
|
||||
return self.out_layer(self.activation(self.in_layer(x)))
|
||||
|
||||
def _forward_chunked(self, x):
|
||||
chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
output_chunks = []
|
||||
for chunk in chunks:
|
||||
output_chunks.append(self._forward(chunk))
|
||||
return torch.cat(output_chunks, dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
class OutLayer(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, visual_embed, time_embed):
|
||||
B, T, H, W, _ = visual_embed.shape
|
||||
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
|
||||
scale = scale[:, None, None, None, :]
|
||||
shift = shift[:, None, None, None, :]
|
||||
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
|
||||
x = self.out_layer(visual_embed)
|
||||
|
||||
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
|
||||
x = x.view(
|
||||
B, T, H, W,
|
||||
out_dim,
|
||||
self.patch_size[0], self.patch_size[1], self.patch_size[2]
|
||||
)
|
||||
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
|
||||
|
||||
|
||||
class TransformerEncoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
|
||||
out = self.self_attention(out, freqs, transformer_options=transformer_options)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
|
||||
out = self.feed_forward(out)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
|
||||
# self attention
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# cross attention
|
||||
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# feed forward
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
|
||||
visual_out = self.feed_forward(visual_out)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
return visual_embed
|
||||
|
||||
|
||||
class Kandinsky5(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
|
||||
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
|
||||
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
|
||||
dtype=None, device=None, operations=None, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
head_dim = sum(axes_dims)
|
||||
self.rope_scale_factor = rope_scale_factor
|
||||
self.in_visual_dim = in_visual_dim
|
||||
self.model_dim = model_dim
|
||||
self.patch_size = patch_size
|
||||
self.visual_embed_dim = visual_embed_dim
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
|
||||
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
|
||||
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
|
||||
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.text_transformer_blocks = nn.ModuleList(
|
||||
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
|
||||
)
|
||||
|
||||
self.visual_transformer_blocks = nn.ModuleList(
|
||||
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
|
||||
)
|
||||
|
||||
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
|
||||
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
|
||||
|
||||
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
|
||||
steps = seq_len if steps is None else steps
|
||||
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
|
||||
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
|
||||
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
else:
|
||||
rope_scale_factor = self.rope_scale_factor
|
||||
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
|
||||
if h * w >= 14080:
|
||||
rope_scale_factor = (1.0, 3.16, 3.16)
|
||||
|
||||
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
|
||||
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
|
||||
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
context = self.text_embeddings(context)
|
||||
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
|
||||
|
||||
for block in self.text_transformer_blocks:
|
||||
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = self.visual_embeddings(x)
|
||||
visual_shape = visual_embed.shape[:-1]
|
||||
visual_embed = visual_embed.flatten(1, -2)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.visual_transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
||||
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
|
||||
else:
|
||||
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||
return self.out_layer(visual_embed, time_embed)
|
||||
|
||||
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
original_dims = x.ndim
|
||||
if original_dims == 4:
|
||||
x = x.unsqueeze(2)
|
||||
bs, c, t_len, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if time_dim_replace is not None:
|
||||
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
|
||||
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
|
||||
|
||||
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
|
||||
if original_dims == 4:
|
||||
out = out.squeeze(2)
|
||||
return out
|
||||
|
||||
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
|
||||
|
|
@ -1,160 +0,0 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .model import JointTransformerBlock
|
||||
|
||||
class ZImageControlTransformerBlock(JointTransformerBlock):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: float,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
block_id=0,
|
||||
operation_settings=None,
|
||||
):
|
||||
super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, c, x, **kwargs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
return c_skip, c
|
||||
|
||||
class ZImage_Control(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 3840,
|
||||
n_heads: int = 30,
|
||||
n_kv_heads: int = 30,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: float = (8.0 / 3.0),
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = True,
|
||||
n_control_layers=6,
|
||||
control_in_dim=16,
|
||||
additional_in_dim=0,
|
||||
broken=False,
|
||||
refiner_control=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.broken = broken
|
||||
self.additional_in_dim = additional_in_dim
|
||||
self.control_in_dim = control_in_dim
|
||||
n_refiner_layers = 2
|
||||
self.n_control_layers = n_control_layers
|
||||
self.control_layers = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
i,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
block_id=i,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for i in range(self.n_control_layers)
|
||||
]
|
||||
)
|
||||
|
||||
all_x_embedder = {}
|
||||
patch_size = 2
|
||||
f_patch_size = 1
|
||||
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
self.refiner_control = refiner_control
|
||||
|
||||
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
if self.refiner_control:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
block_id=layer_id,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=True,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
|
||||
patch_size = 2
|
||||
f_patch_size = 1
|
||||
pH = pW = patch_size
|
||||
B, C, H, W = control_context.shape
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
x_attn_mask = None
|
||||
if not self.refiner_control:
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
||||
|
||||
return control_context
|
||||
|
||||
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
if self.refiner_control:
|
||||
if self.broken:
|
||||
if layer_id == 0:
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if layer_id > 0:
|
||||
out = None
|
||||
for i in range(1, len(self.control_layers)):
|
||||
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if out is None:
|
||||
out = o
|
||||
|
||||
return (out, control_context)
|
||||
else:
|
||||
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
else:
|
||||
return (None, control_context)
|
||||
|
||||
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
|
|
@ -22,10 +22,6 @@ def modulate(x, scale):
|
|||
# Core NextDiT Model #
|
||||
#############################################################################
|
||||
|
||||
def clamp_fp16(x):
|
||||
if x.dtype == torch.float16:
|
||||
return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
class JointAttention(nn.Module):
|
||||
"""Multi-head attention module."""
|
||||
|
|
@ -173,7 +169,7 @@ class FeedForward(nn.Module):
|
|||
|
||||
# @torch.compile
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return clamp_fp16(F.silu(x1) * x3)
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
|
@ -277,27 +273,27 @@ class JointTransformerBlock(nn.Module):
|
|||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
clamp_fp16(self.attention(
|
||||
self.attention(
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
))
|
||||
)
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
clamp_fp16(self.feed_forward(
|
||||
self.feed_forward(
|
||||
modulate(self.ffn_norm1(x), scale_mlp),
|
||||
))
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert adaln_input is None
|
||||
x = x + self.attention_norm2(
|
||||
clamp_fp16(self.attention(
|
||||
self.attention(
|
||||
self.attention_norm1(x),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
))
|
||||
)
|
||||
)
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
|
|
@ -377,7 +373,6 @@ class NextDiT(nn.Module):
|
|||
z_image_modulation=False,
|
||||
time_scale=1.0,
|
||||
pad_tokens_multiple=None,
|
||||
clip_text_dim=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
|
|
@ -448,31 +443,6 @@ class NextDiT(nn.Module):
|
|||
),
|
||||
)
|
||||
|
||||
self.clip_text_pooled_proj = None
|
||||
|
||||
if clip_text_dim is not None:
|
||||
self.clip_text_dim = clip_text_dim
|
||||
self.clip_text_pooled_proj = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
clip_text_dim,
|
||||
clip_text_dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.time_text_embed = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024) + clip_text_dim,
|
||||
min(dim, 1024),
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
|
|
@ -491,8 +461,7 @@ class NextDiT(nn.Module):
|
|||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
|
||||
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
|
|
@ -537,7 +506,6 @@ class NextDiT(nn.Module):
|
|||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
orig_x = x
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
|
|
@ -574,21 +542,13 @@ class NextDiT(nn.Module):
|
|||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||
|
||||
padded_img_mask = None
|
||||
x_input = x
|
||||
for i, layer in enumerate(self.noise_refiner):
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||
if "noise_refiner" in patches:
|
||||
for p in patches["noise_refiner"]:
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
if "img" in out:
|
||||
x = out["img"]
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
mask = None
|
||||
|
|
@ -604,7 +564,7 @@ class NextDiT(nn.Module):
|
|||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
|
|
@ -621,36 +581,16 @@ class NextDiT(nn.Module):
|
|||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
pooled = self.clip_text_pooled_proj(pooled)
|
||||
else:
|
||||
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
|
||||
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(x.device)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.layers)
|
||||
transformer_options["block_type"] = "double"
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
transformer_options["block_index"] = i
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
if "img" in out:
|
||||
img[:, cap_size[0]:] = out["img"]
|
||||
if "txt" in out:
|
||||
img[:, :cap_size[0]] = out["txt"]
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
|
||||
img = self.final_layer(img, adaln_input)
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
||||
|
||||
return -img
|
||||
return -x
|
||||
|
||||
|
|
|
|||
|
|
@ -30,13 +30,6 @@ except ImportError as e:
|
|||
raise e
|
||||
exit(-1)
|
||||
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattn3 import sageattn3_blackwell
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
|
|
@ -524,7 +517,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout = "HND"
|
||||
|
|
@ -549,8 +541,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
except Exception as e:
|
||||
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||
exception_fallback = True
|
||||
if exception_fallback:
|
||||
if tensor_layout == "NHD":
|
||||
q, k, v = map(
|
||||
lambda t: t.transpose(1, 2),
|
||||
|
|
@ -570,93 +560,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
@wrap_attn
|
||||
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
exception_fallback = False
|
||||
if (q.device.type != "cuda" or
|
||||
q.dtype not in (torch.float16, torch.bfloat16) or
|
||||
mask is not None):
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if skip_reshape:
|
||||
B, H, L, D = q.shape
|
||||
if H != heads:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
q_s, k_s, v_s = q, k, v
|
||||
N = q.shape[2]
|
||||
dim_head = D
|
||||
else:
|
||||
B, N, inner_dim = q.shape
|
||||
if inner_dim % heads != 0:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
dim_head = inner_dim // heads
|
||||
|
||||
if dim_head >= 256 or N <= 1024:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not skip_reshape:
|
||||
q_s, k_s, v_s = map(
|
||||
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
B, H, L, D = q_s.shape
|
||||
|
||||
try:
|
||||
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
|
||||
except Exception as e:
|
||||
exception_fallback = True
|
||||
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
|
||||
|
||||
if exception_fallback:
|
||||
if not skip_reshape:
|
||||
del q_s, k_s, v_s
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if skip_reshape:
|
||||
if not skip_output_reshape:
|
||||
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||
else:
|
||||
if skip_output_reshape:
|
||||
pass
|
||||
else:
|
||||
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||
|
||||
return out
|
||||
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
|
|
@ -744,8 +647,6 @@ optimized_attention_masked = optimized_attention
|
|||
# register core-supported attention functions
|
||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("sage", attention_sage)
|
||||
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||
register_attention_function("sage3", attention3_sage)
|
||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("flash", attention_flash)
|
||||
if model_management.xformers_enabled():
|
||||
|
|
|
|||
|
|
@ -13,12 +13,6 @@ if model_management.xformers_enabled_vae():
|
|||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
def torch_cat_if_needed(xl, dim):
|
||||
if len(xl) > 1:
|
||||
return torch.cat(xl, dim)
|
||||
else:
|
||||
return xl[0]
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
|
|
@ -49,37 +43,6 @@ def Normalize(in_channels, num_groups=32):
|
|||
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class CarriedConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
|
||||
|
||||
x = xl[0]
|
||||
xl.clear()
|
||||
|
||||
if isinstance(op, CarriedConv3d):
|
||||
if conv_carry_in is None:
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
|
||||
else:
|
||||
carry_len = conv_carry_in[0].shape[2]
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
|
||||
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
|
||||
|
||||
if conv_carry_out is not None:
|
||||
to_push = x[:, :, -2:, :, :].clone()
|
||||
conv_carry_out.append(to_push)
|
||||
|
||||
out = op(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class VideoConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
|
||||
super().__init__()
|
||||
|
|
@ -126,24 +89,29 @@ class Upsample(nn.Module):
|
|||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
def forward(self, x):
|
||||
scale_factor = self.scale_factor
|
||||
if isinstance(scale_factor, (int, float)):
|
||||
scale_factor = (scale_factor,) * (x.ndim - 2)
|
||||
|
||||
if x.ndim == 5 and scale_factor[0] > 1.0:
|
||||
results = []
|
||||
if conv_carry_in is None:
|
||||
first = x[:, :, :1, :, :]
|
||||
results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
|
||||
x = x[:, :, 1:, :, :]
|
||||
if x.shape[2] > 0:
|
||||
results.append(interpolate_up(x, scale_factor))
|
||||
x = torch_cat_if_needed(results, dim=2)
|
||||
t = x.shape[2]
|
||||
if t > 1:
|
||||
a, b = x.split((1, t - 1), dim=2)
|
||||
del x
|
||||
b = interpolate_up(b, scale_factor)
|
||||
else:
|
||||
a = x
|
||||
|
||||
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
|
||||
if t > 1:
|
||||
x = torch.cat((a, b), dim=2)
|
||||
else:
|
||||
x = a
|
||||
else:
|
||||
x = interpolate_up(x, scale_factor)
|
||||
if self.with_conv:
|
||||
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
|
|
@ -159,20 +127,17 @@ class Downsample(nn.Module):
|
|||
stride=stride,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
if isinstance(self.conv, CarriedConv3d):
|
||||
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
elif x.ndim == 4:
|
||||
if x.ndim == 4:
|
||||
pad = (0, 1, 0, 1)
|
||||
mode = "constant"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
|
||||
x = self.conv(x)
|
||||
elif x.ndim == 5:
|
||||
pad = (1, 1, 1, 1, 2, 0)
|
||||
mode = "replicate"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode)
|
||||
x = self.conv(x)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
|
@ -218,23 +183,23 @@ class ResnetBlock(nn.Module):
|
|||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
|
||||
def forward(self, x, temb=None):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = [ self.swish(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
h = self.swish(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.swish(h)
|
||||
h = [ self.dropout(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
|
|
@ -314,7 +279,6 @@ def pytorch_attention(q, k, v):
|
|||
orig_shape = q.shape
|
||||
B = orig_shape[0]
|
||||
C = orig_shape[1]
|
||||
oom_fallback = False
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||
(q, k, v),
|
||||
|
|
@ -325,8 +289,6 @@ def pytorch_attention(q, k, v):
|
|||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
oom_fallback = True
|
||||
if oom_fallback:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
|
||||
return out
|
||||
|
||||
|
|
@ -394,8 +356,7 @@ class Model(nn.Module):
|
|||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
|
|
@ -549,22 +510,16 @@ class Encoder(nn.Module):
|
|||
conv3d=False, time_compress=None,
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.carried = False
|
||||
|
||||
if conv3d:
|
||||
if not attn_resolutions:
|
||||
conv_op = CarriedConv3d
|
||||
self.carried = True
|
||||
else:
|
||||
conv_op = VideoConv3d
|
||||
conv_op = VideoConv3d
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
|
|
@ -577,7 +532,6 @@ class Encoder(nn.Module):
|
|||
stride=1,
|
||||
padding=1)
|
||||
|
||||
self.time_compress = 1
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
|
|
@ -604,15 +558,10 @@ class Encoder(nn.Module):
|
|||
if time_compress is not None:
|
||||
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
|
||||
stride = (1, 2, 2)
|
||||
else:
|
||||
self.time_compress *= 2
|
||||
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
if time_compress is not None:
|
||||
self.time_compress = time_compress
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
|
|
@ -638,42 +587,15 @@ class Encoder(nn.Module):
|
|||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
if self.carried:
|
||||
xl = [x[:, :, :1, :, :]]
|
||||
if x.shape[2] > self.time_compress:
|
||||
tc = self.time_compress
|
||||
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
|
||||
x = xl
|
||||
else:
|
||||
x = [x]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
|
||||
# downsampling
|
||||
x1 = [ x1 ]
|
||||
h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
assert i == 0 #carried should not happen if attn exists
|
||||
h1 = self.down[i_level].attn[i_block](h1)
|
||||
if i_level != self.num_resolutions-1:
|
||||
h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(h1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
h = torch_cat_if_needed(out, dim=2)
|
||||
del out
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h, temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
|
|
@ -682,15 +604,15 @@ class Encoder(nn.Module):
|
|||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = [ nonlinearity(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv_out)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, tanh_out=False, use_linear_attn=False,
|
||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||
conv_out_op=ops.Conv2d,
|
||||
resnet_op=ResnetBlock,
|
||||
attn_op=AttnBlock,
|
||||
|
|
@ -704,18 +626,12 @@ class Decoder(nn.Module):
|
|||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
self.carried = False
|
||||
|
||||
if conv3d:
|
||||
if not attn_resolutions and resnet_op == ResnetBlock:
|
||||
conv_op = CarriedConv3d
|
||||
conv_out_op = CarriedConv3d
|
||||
self.carried = True
|
||||
else:
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
|
|
@ -790,43 +706,29 @@ class Decoder(nn.Module):
|
|||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = conv_carry_causal_3d([z], self.conv_in)
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, **kwargs)
|
||||
h = self.mid.attn_1(h, **kwargs)
|
||||
h = self.mid.block_2(h, temb, **kwargs)
|
||||
|
||||
if self.carried:
|
||||
h = torch.split(h, 2, dim=2)
|
||||
else:
|
||||
h = [ h ]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
# upsampling
|
||||
for i, h1 in enumerate(h):
|
||||
conv_carry_out = []
|
||||
if i == len(h) - 1:
|
||||
conv_carry_out = None
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
assert i == 0 #carried should not happen if attn exists
|
||||
h1 = self.up[i_level].attn[i_block](h1, **kwargs)
|
||||
if i_level != 0:
|
||||
h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
h1 = self.norm_out(h1)
|
||||
h1 = [ nonlinearity(h1) ]
|
||||
h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
if self.tanh_out:
|
||||
h1 = torch.tanh(h1)
|
||||
out.append(h1)
|
||||
conv_carry_in = conv_carry_out
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
return out
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, **kwargs)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class LitEma(nn.Module):
|
|||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
else:
|
||||
assert key not in self.m_name2s_name
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
|
|
@ -54,7 +54,7 @@ class LitEma(nn.Module):
|
|||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert key not in self.m_name2s_name
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
|
|||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
|
|
@ -72,19 +72,9 @@ class QwenTimestepProjEmbeddings(nn.Module):
|
|||
operations=operations
|
||||
)
|
||||
|
||||
self.use_additional_t_cond = use_additional_t_cond
|
||||
if self.use_additional_t_cond:
|
||||
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, timestep, hidden_states, addition_t_cond=None):
|
||||
def forward(self, timestep, hidden_states):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
if self.use_additional_t_cond:
|
||||
if addition_t_cond is None:
|
||||
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
|
||||
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
|
|
@ -228,24 +218,9 @@ class QwenImageTransformerBlock(nn.Module):
|
|||
operations=operations,
|
||||
)
|
||||
|
||||
def _apply_gate(self, x, y, gate, timestep_zero_index=None):
|
||||
if timestep_zero_index is not None:
|
||||
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
|
||||
else:
|
||||
return torch.addcmul(y, gate, x)
|
||||
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
|
||||
if timestep_zero_index is not None:
|
||||
actual_batch = shift.size(0) // 2
|
||||
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
|
||||
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
|
||||
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
|
||||
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
|
||||
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
|
||||
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
|
||||
else:
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -254,19 +229,14 @@ class QwenImageTransformerBlock(nn.Module):
|
|||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod_params = self.img_mod(temb)
|
||||
|
||||
if timestep_zero_index is not None:
|
||||
temb = temb.chunk(2, dim=0)[0]
|
||||
|
||||
txt_mod_params = self.txt_mod(temb)
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
||||
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||
del img_mod1
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
del txt_mod1
|
||||
|
|
@ -281,15 +251,15 @@ class QwenImageTransformerBlock(nn.Module):
|
|||
del img_modulated
|
||||
del txt_modulated
|
||||
|
||||
hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
del img_attn_output
|
||||
del txt_attn_output
|
||||
del img_gate1
|
||||
del txt_gate1
|
||||
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
|
||||
hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||
|
|
@ -330,11 +300,10 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
default_ref_method="index",
|
||||
image_model=None,
|
||||
final_layer=True,
|
||||
use_additional_t_cond=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
|
|
@ -345,14 +314,12 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.default_ref_method = default_ref_method
|
||||
|
||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
use_additional_t_cond=use_additional_t_cond,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
|
|
@ -374,9 +341,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
if self.default_ref_method == "index_timestep_zero":
|
||||
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
|
||||
|
||||
if final_layer:
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
|
@ -386,33 +350,27 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
patch_size = self.patch_size
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
t_len = t
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
|
||||
if t_len > 1:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
|
||||
else:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
|
||||
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
|
|
@ -420,8 +378,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
additional_t_cond=None,
|
||||
transformer_options={},
|
||||
control=None,
|
||||
**kwargs
|
||||
|
|
@ -433,24 +391,16 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
num_embeds = hidden_states.shape[1]
|
||||
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
|
||||
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
||||
negative_ref_method = ref_method == "negative_index"
|
||||
timestep_zero = ref_method == "index_timestep_zero"
|
||||
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif negative_ref_method:
|
||||
index -= 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
|
|
@ -465,10 +415,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = num_embeds
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
|
|
@ -480,7 +426,14 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
|
@ -493,7 +446,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
|
|
@ -505,7 +458,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
timestep_zero_index=timestep_zero_index,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
|
@ -522,12 +474,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
if timestep_zero_index is not None:
|
||||
temb = temb.chunk(2, dim=0)[0]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def count_params(model, verbose=False):
|
|||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if "target" not in config:
|
||||
if not "target" in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
|
|
|
|||
|
|
@ -568,10 +568,7 @@ class WanModel(torch.nn.Module):
|
|||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
|
@ -766,10 +763,7 @@ class VaceWanModel(WanModel):
|
|||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
|
@ -868,10 +862,7 @@ class CameraWanModel(WanModel):
|
|||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
|
@ -1335,19 +1326,16 @@ class WanModel_S2V(WanModel):
|
|||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
|
||||
x = block(x, e=e0, freqs=freqs, context=context)
|
||||
if audio_emb is not None:
|
||||
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
|
||||
# head
|
||||
|
|
@ -1586,10 +1574,7 @@ class HumoWanModel(WanModel):
|
|||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
|
|
|||
|
|
@ -523,10 +523,7 @@ class AnimateWanModel(WanModel):
|
|||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
|
|
|||
|
|
@ -227,7 +227,6 @@ class Encoder3d(nn.Module):
|
|||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
input_channels=3,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
|
|
@ -246,7 +245,7 @@ class Encoder3d(nn.Module):
|
|||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
|
|
@ -332,7 +331,6 @@ class Decoder3d(nn.Module):
|
|||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
output_channels=3,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
|
|
@ -380,7 +378,7 @@ class Decoder3d(nn.Module):
|
|||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
|
|
@ -451,7 +449,6 @@ class WanVAE(nn.Module):
|
|||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
image_channels=3,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@ -463,11 +460,11 @@ class WanVAE(nn.Module):
|
|||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x):
|
||||
|
|
|
|||
|
|
@ -320,16 +320,8 @@ def model_lora_keys_unet(model, key_map={}):
|
|||
to = diffusers_keys[k]
|
||||
key_lora = k[:-len(".weight")]
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ import comfy.ldm.chroma_radiance.model
|
|||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
|
@ -135,7 +134,7 @@ class BaseModel(torch.nn.Module):
|
|||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
|
|
@ -330,6 +329,18 @@ class BaseModel(torch.nn.Module):
|
|||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
# Save mixed precision metadata
|
||||
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
|
||||
metadata = {
|
||||
"format_version": "1.0",
|
||||
"layers": self.model_config.layer_quant_config
|
||||
}
|
||||
unet_state_dict["_quantization_metadata"] = metadata
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
|
|
@ -1110,10 +1121,6 @@ class Lumina2(BaseModel):
|
|||
if 'num_tokens' not in out:
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
||||
|
||||
clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
return out
|
||||
|
||||
class WAN21(BaseModel):
|
||||
|
|
@ -1635,49 +1642,3 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
|
|||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(False)
|
||||
return out
|
||||
|
||||
class Kandinsky5(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
else:
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
time_dim_replace = kwargs.get("time_dim_replace", None)
|
||||
if time_dim_replace is not None:
|
||||
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
|
||||
|
||||
return out
|
||||
|
||||
class Kandinsky5Image(Kandinsky5):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -6,6 +6,20 @@ import math
|
|||
import logging
|
||||
import torch
|
||||
|
||||
|
||||
def detect_layer_quantization(metadata):
|
||||
quant_key = "_quantization_metadata"
|
||||
if metadata is not None and quant_key in metadata:
|
||||
quant_metadata = metadata.pop(quant_key)
|
||||
quant_metadata = json.loads(quant_metadata)
|
||||
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
|
||||
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||
return quant_metadata["layers"]
|
||||
else:
|
||||
raise ValueError("Invalid quantization metadata format")
|
||||
return None
|
||||
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
while True:
|
||||
|
|
@ -180,10 +194,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||
dit_config["use_cond_type_embedding"] = False
|
||||
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["meanflow_sum"] = True
|
||||
else:
|
||||
dit_config["vision_in_dim"] = None
|
||||
dit_config["meanflow_sum"] = False
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
|
|
@ -196,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||
dit_config["theta"] = 2000
|
||||
dit_config["out_channels"] = 128
|
||||
dit_config["global_modulation"] = True
|
||||
dit_config["vec_in_dim"] = None
|
||||
dit_config["mlp_silu_act"] = True
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["ops_bias"] = False
|
||||
dit_config["default_ref_method"] = "index"
|
||||
dit_config["ref_index_scale"] = 10.0
|
||||
dit_config["txt_ids_dims"] = [3]
|
||||
patch_size = 1
|
||||
else:
|
||||
dit_config["image_model"] = "flux"
|
||||
|
|
@ -211,7 +223,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||
dit_config["theta"] = 10000
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["qkv_bias"] = True
|
||||
dit_config["txt_ids_dims"] = []
|
||||
patch_size = 2
|
||||
|
||||
dit_config["in_channels"] = 16
|
||||
|
|
@ -234,8 +245,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
|
||||
if vec_in_key in state_dict_keys:
|
||||
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
|
|
@ -259,17 +268,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
else:
|
||||
dit_config["use_x0"] = False
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
|
||||
|
|
@ -429,10 +429,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
dit_config["ffn_dim_multiplier"] = 4.0
|
||||
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
||||
if ctd_weight is not None: # NewBie
|
||||
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
||||
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
|
||||
elif dit_config["dim"] == 3840: # Z image
|
||||
dit_config["n_heads"] = 30
|
||||
dit_config["n_kv_heads"] = 30
|
||||
|
|
@ -619,29 +615,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||
dit_config["image_model"] = "qwen_image"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
|
||||
dit_config["default_ref_method"] = "index_timestep_zero"
|
||||
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
|
||||
dit_config["use_additional_t_cond"] = True
|
||||
dit_config["default_ref_method"] = "negative_index"
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
dit_config = {}
|
||||
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["model_dim"] = model_dim
|
||||
if model_dim in [4096, 2560]: # pro video and lite image
|
||||
dit_config["axes_dims"] = (32, 48, 48)
|
||||
if model_dim == 2560: # lite image
|
||||
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
|
||||
elif model_dim == 1792: # lite video
|
||||
dit_config["axes_dims"] = (16, 24, 24)
|
||||
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["image_model"] = "kandinsky5"
|
||||
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
|
|
@ -786,11 +759,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||
if model_config is None and use_base_if_no_match:
|
||||
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
model_config.optimizations["fp8"] = False
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
# Detect per-layer quantization (mixed precision)
|
||||
quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
|
||||
if quant_config:
|
||||
model_config.quant_config = quant_config
|
||||
logging.info("Detected mixed precision quantization")
|
||||
layer_quant_config = detect_layer_quantization(metadata)
|
||||
if layer_quant_config:
|
||||
model_config.layer_quant_config = layer_quant_config
|
||||
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
|
||||
|
||||
return model_config
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ import importlib
|
|||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
|
|
@ -334,15 +333,13 @@ except:
|
|||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||
|
||||
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
||||
AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
|
|
@ -1019,8 +1016,8 @@ NUM_STREAMS = 0
|
|||
if args.async_offload is not None:
|
||||
NUM_STREAMS = args.async_offload
|
||||
else:
|
||||
# Enable by default on Nvidia and AMD
|
||||
if is_nvidia() or is_amd():
|
||||
# Enable by default on Nvidia
|
||||
if is_nvidia():
|
||||
NUM_STREAMS = 2
|
||||
|
||||
if args.disable_async_offload:
|
||||
|
|
@ -1126,16 +1123,6 @@ if not args.disable_pinned_memory:
|
|||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
|
||||
def discard_cuda_async_error():
|
||||
try:
|
||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
_ = a + b
|
||||
torch.cuda.synchronize()
|
||||
except torch.AcceleratorError:
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
|
||||
def pin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
|
|
@ -1168,9 +1155,6 @@ def pin_memory(tensor):
|
|||
PINNED_MEMORY[ptr] = size
|
||||
TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
else:
|
||||
logging.warning("Pin error.")
|
||||
discard_cuda_async_error()
|
||||
|
||||
return False
|
||||
|
||||
|
|
@ -1199,9 +1183,6 @@ def unpin_memory(tensor):
|
|||
if len(PINNED_MEMORY) == 0:
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
return True
|
||||
else:
|
||||
logging.warning("Unpin error.")
|
||||
discard_cuda_async_error()
|
||||
|
||||
return False
|
||||
|
||||
|
|
@ -1511,20 +1492,6 @@ def extended_fp16_support():
|
|||
|
||||
return True
|
||||
|
||||
LORA_COMPUTE_DTYPES = {}
|
||||
def lora_compute_dtype(device):
|
||||
dtype = LORA_COMPUTE_DTYPES.get(device, None)
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
|
||||
if should_use_fp16(device):
|
||||
dtype = torch.float16
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
LORA_COMPUTE_DTYPES[device] = dtype
|
||||
return dtype
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
|
|
@ -1542,10 +1509,6 @@ def soft_empty_cache(force=False):
|
|||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ import comfy.model_management
|
|||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
|
||||
|
||||
|
|
@ -127,23 +126,36 @@ class LowVramPatch:
|
|||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||
self.key = key
|
||||
self.patches = patches
|
||||
self.convert_func = convert_func # TODO: remove
|
||||
self.convert_func = convert_func
|
||||
self.set_func = set_func
|
||||
|
||||
def __call__(self, weight):
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
intermediate_dtype = weight.dtype
|
||||
if self.convert_func is not None:
|
||||
weight = self.convert_func(weight, inplace=False)
|
||||
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
||||
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||
intermediate_dtype = torch.float32
|
||||
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
|
||||
if self.set_func is None:
|
||||
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
|
||||
else:
|
||||
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
|
||||
|
||||
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||
if self.set_func is not None:
|
||||
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
|
||||
else:
|
||||
return out
|
||||
|
||||
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
|
||||
|
||||
def low_vram_patch_estimate_vram(model, key):
|
||||
weight, set_func, convert_func = get_key_weight(model, key)
|
||||
if weight is None:
|
||||
return 0
|
||||
model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
|
||||
if model_dtype is None:
|
||||
model_dtype = weight.dtype
|
||||
|
||||
return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
|
||||
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
|
||||
|
||||
def get_key_weight(model, key):
|
||||
set_func = None
|
||||
|
|
@ -454,9 +466,6 @@ class ModelPatcher:
|
|||
def set_model_post_input_patch(self, patch):
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_noise_refiner_patch(self, patch):
|
||||
self.set_model_patch(patch, "noise_refiner")
|
||||
|
||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||
rope_options["scale_x"] = scale_x
|
||||
|
|
@ -621,11 +630,10 @@ class ModelPatcher:
|
|||
if key not in self.backup:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
||||
if device_to is not None:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(temp_dtype, copy=True)
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
if convert_func is not None:
|
||||
temp_weight = convert_func(temp_weight, inplace=True)
|
||||
|
||||
|
|
@ -669,18 +677,12 @@ class ModelPatcher:
|
|||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
def check_module_offload_mem(key):
|
||||
if key in self.patches:
|
||||
return low_vram_patch_estimate_vram(self.model, key)
|
||||
model_dtype = getattr(self.model, "manual_cast_dtype", None)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if model_dtype is None or weight is None:
|
||||
return 0
|
||||
if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
|
||||
return weight.numel() * model_dtype.itemsize
|
||||
return 0
|
||||
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
||||
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
if weight_key in self.patches:
|
||||
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
|
||||
if bias_key in self.patches:
|
||||
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
|
||||
loading.append((module_offload_mem, module_mem, n, m, params))
|
||||
return loading
|
||||
|
||||
|
|
@ -697,12 +699,12 @@ class ModelPatcher:
|
|||
offloaded = []
|
||||
offload_buffer = 0
|
||||
loading.sort(reverse=True)
|
||||
for i, x in enumerate(loading):
|
||||
for x in loading:
|
||||
module_offload_mem, module_mem, n, m, params = x
|
||||
|
||||
lowvram_weight = False
|
||||
|
||||
potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
|
||||
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
|
||||
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
|
|
@ -775,8 +777,6 @@ class ModelPatcher:
|
|||
key = "{}.{}".format(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
if comfy.model_management.is_device_cuda(device_to):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
|
|
@ -876,18 +876,14 @@ class ModelPatcher:
|
|||
patch_counter = 0
|
||||
unload_list = self._load_list()
|
||||
unload_list.sort()
|
||||
|
||||
offload_buffer = self.model.model_offload_buffer_memory
|
||||
if len(unload_list) > 0:
|
||||
NS = comfy.model_management.NUM_STREAMS
|
||||
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
|
||||
|
||||
for unload in unload_list:
|
||||
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
|
||||
break
|
||||
module_offload_mem, module_mem, n, m, params = unload
|
||||
|
||||
potential_offload = module_offload_mem + sum(offload_weight_factor)
|
||||
potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
|
||||
|
||||
lowvram_possible = hasattr(m, "comfy_cast_weights")
|
||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||
|
|
@ -933,14 +929,12 @@ class ModelPatcher:
|
|||
patch_counter += 1
|
||||
cast_weight = True
|
||||
|
||||
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
||||
if cast_weight:
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
m.comfy_patched_weights = False
|
||||
memory_freed += module_mem
|
||||
offload_buffer = max(offload_buffer, potential_offload)
|
||||
offload_weight_factor.append(module_mem)
|
||||
offload_weight_factor.pop(0)
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
for param in params:
|
||||
|
|
|
|||
219
comfy/ops.py
219
comfy/ops.py
|
|
@ -22,7 +22,7 @@ import comfy.model_management
|
|||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import json
|
||||
import contextlib
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
|
|
@ -93,6 +93,13 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||
else:
|
||||
offload_stream = None
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(offload_stream)
|
||||
else:
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
|
|
@ -104,24 +111,22 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||
if s.bias is not None:
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
bias_a = bias
|
||||
weight_a = weight
|
||||
|
||||
if s.bias is not None:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
if bias_has_function:
|
||||
with wf_context:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
if weight_has_function or weight.dtype != dtype:
|
||||
weight = weight.to(dtype=dtype)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
with wf_context:
|
||||
weight = weight.to(dtype=dtype)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
if offloadable:
|
||||
return weight, bias, (offload_stream, weight_a, bias_a)
|
||||
return weight, bias, offload_stream
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
|
|
@ -130,16 +135,13 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
if offload_stream is None:
|
||||
return
|
||||
os, weight_a, bias_a = offload_stream
|
||||
if os is None:
|
||||
return
|
||||
if weight_a is not None:
|
||||
device = weight_a.device
|
||||
if weight is not None:
|
||||
device = weight.device
|
||||
else:
|
||||
if bias_a is None:
|
||||
if bias is None:
|
||||
return
|
||||
device = bias_a.device
|
||||
os.wait_stream(comfy.model_management.current_stream(device))
|
||||
device = bias.device
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||
|
||||
|
||||
class CastWeightBiasOp:
|
||||
|
|
@ -415,12 +417,22 @@ def fp8_linear(self, input):
|
|||
|
||||
if input.ndim == 3 or input.ndim == 2:
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
if scale_weight is None:
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
else:
|
||||
scale_weight = scale_weight.to(input.device)
|
||||
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
|
||||
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
|
|
@ -441,7 +453,7 @@ class fp8_ops(manual_cast):
|
|||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if len(self.weight_function) == 0 and len(self.bias_function) == 0:
|
||||
if not self.training:
|
||||
try:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
|
|
@ -454,6 +466,59 @@ class fp8_ops(manual_cast):
|
|||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||
class scaled_fp8_op(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
if override_dtype is not None:
|
||||
kwargs['dtype'] = override_dtype
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset_parameters(self):
|
||||
if not hasattr(self, 'scale_weight'):
|
||||
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
|
||||
if not scale_input:
|
||||
self.scale_input = None
|
||||
|
||||
if not hasattr(self, 'scale_input'):
|
||||
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if fp8_matrix_mult:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
|
||||
if weight.numel() < input.numel(): #TODO: optimize
|
||||
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
else:
|
||||
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
return weight
|
||||
else:
|
||||
return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||
if return_weight:
|
||||
return weight
|
||||
if inplace_update:
|
||||
self.weight.data.copy_(weight)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
return scaled_fp8_op
|
||||
|
||||
CUBLAS_IS_AVAILABLE = False
|
||||
try:
|
||||
from cublas_ops import CublasLinear
|
||||
|
|
@ -480,9 +545,9 @@ if CUBLAS_IS_AVAILABLE:
|
|||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
||||
def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
_layer_quant_config = layer_quant_config
|
||||
_compute_dtype = compute_dtype
|
||||
_full_precision_mm = full_precision_mm
|
||||
|
||||
|
|
@ -497,14 +562,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if dtype is None:
|
||||
dtype = MixedPrecisionOps._compute_dtype
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self._has_bias = bias
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.tensor_class = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
|
|
@ -524,59 +590,36 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
if layer_conf is None:
|
||||
dtype = self.factory_kwargs["dtype"]
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
|
||||
if dtype != MixedPrecisionOps._compute_dtype:
|
||||
self.comfy_cast_weights = True
|
||||
if self._has_bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
if layer_name not in MixedPrecisionOps._layer_quant_config:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
|
||||
|
||||
if self.quant_format is None:
|
||||
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||
if quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(weight_scale_key, None)
|
||||
if scale is not None:
|
||||
scale = scale.to(device)
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'scale': state_dict.pop(weight_scale_key, None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
|
||||
if scale is not None:
|
||||
if layout_params['scale'] is not None:
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
|
||||
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
if self._has_bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
|
@ -585,16 +628,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
return sd
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
|
|
@ -610,8 +643,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
getattr(self, 'input_scale', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
|
|
@ -622,7 +656,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
|
@ -631,28 +665,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
||||
if recurse:
|
||||
for module in self.children():
|
||||
module._apply(fn)
|
||||
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
return MixedPrecisionOps
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
||||
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
||||
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
if (
|
||||
fp8_compute and
|
||||
|
|
|
|||
|
|
@ -238,9 +238,6 @@ class QuantizedTensor(torch.Tensor):
|
|||
def is_contiguous(self, *arg, **kwargs):
|
||||
return self._qdata.is_contiguous(*arg, **kwargs)
|
||||
|
||||
def storage(self):
|
||||
return self._qdata.storage()
|
||||
|
||||
# ==============================================================================
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
# ==============================================================================
|
||||
|
|
@ -252,6 +249,12 @@ def _create_transformed_qtensor(qt, transform_fn):
|
|||
|
||||
|
||||
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
||||
if target_dtype is not None and target_dtype != qt.dtype:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
|
||||
f"but not supported for quantized tensors. Ignoring dtype."
|
||||
)
|
||||
|
||||
if target_layout is not None and target_layout != torch.strided:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||
|
|
@ -271,8 +274,6 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
|
|||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||
new_q_data = qt._qdata.to(device=target_device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
if target_dtype is not None:
|
||||
new_params["orig_dtype"] = target_dtype
|
||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||
return new_qt
|
||||
|
|
@ -338,9 +339,7 @@ def generic_copy_(func, args, kwargs):
|
|||
# Copy from another quantized tensor
|
||||
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
|
||||
qt_dest._layout_type = src._layout_type
|
||||
orig_dtype = qt_dest._layout_params["orig_dtype"]
|
||||
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
|
||||
qt_dest._layout_params["orig_dtype"] = orig_dtype
|
||||
else:
|
||||
# Copy from regular tensor - just copy raw data
|
||||
qt_dest._qdata.copy_(src)
|
||||
|
|
@ -398,23 +397,17 @@ class TensorCoreFP8Layout(QuantizedLayout):
|
|||
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
|
||||
orig_dtype = tensor.dtype
|
||||
|
||||
if isinstance(scale, str) and scale == "recalculate":
|
||||
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
|
||||
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
|
||||
tensor_info = torch.finfo(tensor.dtype)
|
||||
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
|
||||
if scale is None:
|
||||
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
|
||||
|
||||
if scale is not None:
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
if inplace_ops:
|
||||
tensor *= (1.0 / scale).to(tensor.dtype)
|
||||
else:
|
||||
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
if inplace_ops:
|
||||
tensor *= (1.0 / scale).to(tensor.dtype)
|
||||
else:
|
||||
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
|
||||
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
|
||||
|
|
|
|||
|
|
@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
|
|||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||
return memory_required, minimum_memory_required
|
||||
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_prepare_sampling,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
|
|
|||
|
|
@ -720,7 +720,7 @@ class Sampler:
|
|||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
|
|
@ -984,6 +984,9 @@ class CFGGuider:
|
|||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
|
|
@ -1010,24 +1013,6 @@ class CFGGuider:
|
|||
else:
|
||||
latent_shapes = [latent_image.shape]
|
||||
|
||||
if denoise_mask is not None:
|
||||
if denoise_mask.is_nested:
|
||||
denoise_masks = denoise_mask.unbind()
|
||||
denoise_masks = denoise_masks[:len(latent_shapes)]
|
||||
else:
|
||||
denoise_masks = [denoise_mask]
|
||||
|
||||
for i in range(len(denoise_masks), len(latent_shapes)):
|
||||
denoise_masks.append(torch.ones(latent_shapes[i]))
|
||||
|
||||
for i in range(len(denoise_masks)):
|
||||
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
|
||||
|
||||
if len(denoise_masks) > 1:
|
||||
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
||||
else:
|
||||
denoise_mask = denoise_masks[0]
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
|
|
|
|||
210
comfy/sd.py
210
comfy/sd.py
|
|
@ -53,10 +53,6 @@ import comfy.text_encoders.omnigen2
|
|||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
import comfy.text_encoders.newbie
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
|
|
@ -101,7 +97,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
|||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
|
||||
if no_init:
|
||||
return
|
||||
params = target.params.copy()
|
||||
|
|
@ -129,32 +125,9 @@ class CLIP:
|
|||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
#Match torch.float32 hardcode upcast in TE implemention
|
||||
self.patcher.set_model_compute_dtype(torch.float32)
|
||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
self.patcher.is_clip = True
|
||||
self.apply_hooks_to_conds = None
|
||||
if len(state_dict) > 0:
|
||||
if isinstance(state_dict, list):
|
||||
for c in state_dict:
|
||||
m, u = self.load_sd(c)
|
||||
if len(m) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected: {}".format(u))
|
||||
else:
|
||||
m, u = self.load_sd(state_dict, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
|
||||
if params['device'] == load_device:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
|
|
@ -219,7 +192,6 @@ class CLIP:
|
|||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
if show_pbar:
|
||||
|
|
@ -267,7 +239,6 @@ class CLIP:
|
|||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
if return_dict:
|
||||
|
|
@ -323,7 +294,6 @@ class VAE:
|
|||
self.latent_channels = 4
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
self.pad_channel_value = None
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
|
@ -438,7 +408,6 @@ class VAE:
|
|||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 64
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
|
|
@ -499,7 +468,7 @@ class VAE:
|
|||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
|
|
@ -511,10 +480,8 @@ class VAE:
|
|||
self.latent_dim = 3
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
#This is likely to significantly over-estimate with single image or low frame counts as the
|
||||
#implementation is able to completely skip caching. Rework if used as an image only VAE
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
elif "decoder.unpatcher3d.wavelets" in sd:
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
|
||||
|
|
@ -550,15 +517,11 @@ class VAE:
|
|||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
||||
self.pad_channel_value = 1.0
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
|
||||
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
|
||||
# Hunyuan 3d v2 2.0 & 2.1
|
||||
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
|
||||
|
||||
|
|
@ -588,7 +551,6 @@ class VAE:
|
|||
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 8
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 4096
|
||||
self.downscale_ratio = 4096
|
||||
self.latent_dim = 2
|
||||
|
|
@ -697,28 +659,17 @@ class VAE:
|
|||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
if self.crop_input:
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
if not self.crop_input:
|
||||
return pixels
|
||||
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
x = (dims[d] // downscale_ratio) * downscale_ratio
|
||||
x_offset = (dims[d] % downscale_ratio) // 2
|
||||
if x != dims[d]:
|
||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
if pixels.shape[-1] > self.output_channels:
|
||||
pixels = pixels[..., :self.output_channels]
|
||||
elif pixels.shape[-1] < self.output_channels:
|
||||
if self.pad_channel_value is not None:
|
||||
if isinstance(self.pad_channel_value, str):
|
||||
mode = self.pad_channel_value
|
||||
value = None
|
||||
else:
|
||||
mode = "constant"
|
||||
value = self.pad_channel_value
|
||||
|
||||
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
x = (dims[d] // downscale_ratio) * downscale_ratio
|
||||
x_offset = (dims[d] % downscale_ratio) // 2
|
||||
if x != dims[d]:
|
||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||
return pixels
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
|
|
@ -789,8 +740,6 @@ class VAE:
|
|||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
do_tile = False
|
||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||
samples_in = samples_in[:, :, 0]
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
|
|
@ -1007,18 +956,16 @@ class CLIPType(Enum):
|
|||
QWEN_IMAGE = 18
|
||||
HUNYUAN_IMAGE = 19
|
||||
HUNYUAN_VIDEO_15 = 20
|
||||
OVIS = 21
|
||||
KANDINSKY5 = 22
|
||||
KANDINSKY5_IMAGE = 23
|
||||
NEWBIE = 24
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
||||
if model_options.get("custom_operations", None) is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
|
||||
if metadata is not None:
|
||||
quant_metadata = metadata.get("_quantization_metadata", None)
|
||||
if quant_metadata is not None:
|
||||
sd["_quantization_metadata"] = quant_metadata
|
||||
clip_data.append(sd)
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
|
||||
|
|
@ -1040,8 +987,6 @@ class TEModel(Enum):
|
|||
MISTRAL3_24B = 14
|
||||
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
||||
QWEN3_4B = 16
|
||||
QWEN3_2B = 17
|
||||
JINA_CLIP_2 = 18
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
|
|
@ -1051,8 +996,6 @@ def detect_te_model(sd):
|
|||
return TEModel.CLIP_H
|
||||
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
||||
return TEModel.CLIP_L
|
||||
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
|
||||
return TEModel.JINA_CLIP_2
|
||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
if weight.shape[-1] == 4096:
|
||||
|
|
@ -1077,12 +1020,9 @@ def detect_te_model(sd):
|
|||
if weight.shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
if weight.shape[0] == 2560:
|
||||
return TEModel.QWEN3_4B
|
||||
elif weight.shape[0] == 2048:
|
||||
return TEModel.QWEN3_2B
|
||||
return TEModel.QWEN3_4B
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if weight.shape[0] == 5120:
|
||||
if "model.layers.39.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.MISTRAL3_24B
|
||||
|
|
@ -1138,7 +1078,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
|
|
@ -1162,7 +1102,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
|
|
@ -1191,7 +1131,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.LLAMA3_8:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif te_model == TEModel.QWEN25_3B:
|
||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||
|
|
@ -1210,19 +1150,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||
elif te_model == TEModel.QWEN3_4B:
|
||||
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
|
||||
elif te_model == TEModel.QWEN3_2B:
|
||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||
elif te_model == TEModel.JINA_CLIP_2:
|
||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
|
|
@ -1265,23 +1199,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
|
||||
elif clip_type == CLIPType.KANDINSKY5:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
|
||||
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||
elif clip_type == CLIPType.NEWBIE:
|
||||
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
|
||||
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
|
||||
clip_data_gemma = clip_data[0]
|
||||
clip_data_jina = clip_data[1]
|
||||
else:
|
||||
clip_data_gemma = clip_data[1]
|
||||
clip_data_jina = clip_data[0]
|
||||
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
|
||||
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
|
|
@ -1294,10 +1211,19 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||
|
||||
parameters = 0
|
||||
for c in clip_data:
|
||||
if "_quantization_metadata" in c:
|
||||
c.pop("_quantization_metadata")
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected: {}".format(u))
|
||||
return clip
|
||||
|
||||
def load_gligen(ckpt_path):
|
||||
|
|
@ -1356,10 +1282,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||
if model_config is None:
|
||||
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
|
|
@ -1368,22 +1290,18 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||
return None
|
||||
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
|
||||
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.quant_config is not None:
|
||||
if model_config.scaled_fp8 is not None:
|
||||
weight_dtype = None
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
|
||||
|
||||
if model_config.quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
|
|
@ -1401,33 +1319,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
|
||||
if output_clip:
|
||||
if te_model_options.get("custom_operations", None) is None:
|
||||
scaled_fp8_list = []
|
||||
for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
|
||||
if k.endswith(".scaled_fp8"):
|
||||
scaled_fp8_list.append(k[:-len("scaled_fp8")])
|
||||
|
||||
if len(scaled_fp8_list) > 0:
|
||||
out_sd = {}
|
||||
for k in sd:
|
||||
skip = False
|
||||
for pref in scaled_fp8_list:
|
||||
skip = skip or k.startswith(pref)
|
||||
if not skip:
|
||||
out_sd[k] = sd[k]
|
||||
|
||||
for pref in scaled_fp8_list:
|
||||
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
|
||||
for k in quant_sd:
|
||||
out_sd[k] = quant_sd[k]
|
||||
sd = out_sd
|
||||
|
||||
clip_target = model_config.clip_target(state_dict=sd)
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
else:
|
||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||
|
||||
|
|
@ -1474,9 +1381,6 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
|||
if len(temp_sd) > 0:
|
||||
sd = temp_sd
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
|
|
@ -1507,7 +1411,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
|||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.quant_config is not None:
|
||||
if model_config.scaled_fp8 is not None:
|
||||
weight_dtype = None
|
||||
|
||||
if dtype is None:
|
||||
|
|
@ -1515,15 +1419,12 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
|||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
if model_config.quant_config is not None:
|
||||
if model_config.layer_quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||
if model_options.get("fp8_optimizations", False):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
|
|
@ -1562,9 +1463,6 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
|||
if vae is not None:
|
||||
vae_sd = vae.get_sd()
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
|
|
|
|||
|
|
@ -107,17 +107,29 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||
config[k] = v
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
quant_config = model_options.get("quantization_metadata", None)
|
||||
scaled_fp8 = None
|
||||
quantization_metadata = model_options.get("quantization_metadata", None)
|
||||
|
||||
if operations is None:
|
||||
if quant_config is not None:
|
||||
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
|
||||
logging.info("Using MixedPrecisionOps for text encoder")
|
||||
layer_quant_config = None
|
||||
if quantization_metadata is not None:
|
||||
layer_quant_config = json.loads(quantization_metadata).get("layers", None)
|
||||
|
||||
if layer_quant_config is not None:
|
||||
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
|
||||
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
# Fallback to scaled_fp8_ops for backward compatibility
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
if scaled_fp8 is not None:
|
||||
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
|
||||
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
|
|
@ -135,7 +147,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
self.return_attention_masks = return_attention_masks
|
||||
self.execution_device = None
|
||||
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
|
|
@ -152,7 +163,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
if isinstance(self.layer, list) or self.layer == "all":
|
||||
pass
|
||||
elif layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
|
|
@ -165,7 +175,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||
self.layer = self.options_default[0]
|
||||
self.layer_idx = self.options_default[1]
|
||||
self.return_projected_pooled = self.options_default[2]
|
||||
self.execution_device = None
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
|
|
@ -249,11 +258,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
|
||||
|
||||
def forward(self, tokens):
|
||||
if self.execution_device is None:
|
||||
device = self.transformer.get_input_embeddings().weight.device
|
||||
else:
|
||||
device = self.execution_device
|
||||
|
||||
device = self.transformer.get_input_embeddings().weight.device
|
||||
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
|
||||
|
||||
attention_mask_model = None
|
||||
|
|
@ -466,7 +471,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||
|
|
@ -513,8 +518,6 @@ class SDTokenizer:
|
|||
self.embedding_size = embedding_size
|
||||
self.embedding_key = embedding_key
|
||||
|
||||
self.disable_weights = disable_weights
|
||||
|
||||
def _try_get_embedding(self, embedding_name:str):
|
||||
'''
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
|
|
@ -549,7 +552,7 @@ class SDTokenizer:
|
|||
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||
|
||||
text = escape_important(text)
|
||||
if kwargs.get("disable_weights", self.disable_weights):
|
||||
if kwargs.get("disable_weights", False):
|
||||
parsed_weights = [(text, 1.0)]
|
||||
else:
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
|
|
|||
|
|
@ -21,14 +21,12 @@ import comfy.text_encoders.ace
|
|||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
|
||||
from . import diffusers_convert
|
||||
import comfy.model_management
|
||||
|
||||
class SD15(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
|
|
@ -542,7 +540,7 @@ class SD3(supported_models_base.BASE):
|
|||
unet_extra_config = {}
|
||||
latent_format = latent_formats.SD3
|
||||
|
||||
memory_usage_factor = 1.6
|
||||
memory_usage_factor = 1.2
|
||||
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
|
|
@ -966,7 +964,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
|||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosPredict2(self, device=device)
|
||||
|
|
@ -1027,15 +1025,7 @@ class ZImage(Lumina2):
|
|||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 2.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
if comfy.model_management.extended_fp16_support():
|
||||
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
|
||||
self.supported_inference_dtypes.insert(1, torch.float16)
|
||||
memory_usage_factor = 1.7
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
|
|
@ -1296,7 +1286,7 @@ class ChromaRadiance(Chroma):
|
|||
latent_format = comfy.latent_formats.ChromaRadiance
|
||||
|
||||
# Pixel-space model, no spatial compression for model input.
|
||||
memory_usage_factor = 0.044
|
||||
memory_usage_factor = 0.038
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ChromaRadiance(self, device=device)
|
||||
|
|
@ -1339,7 +1329,7 @@ class Omnigen2(supported_models_base.BASE):
|
|||
"shift": 2.6,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.95 #TODO
|
||||
memory_usage_factor = 1.65 #TODO
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
|
@ -1404,7 +1394,7 @@ class HunyuanImage21(HunyuanVideo):
|
|||
|
||||
latent_format = latent_formats.HunyuanImage21
|
||||
|
||||
memory_usage_factor = 8.7
|
||||
memory_usage_factor = 7.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
|
|
@ -1482,60 +1472,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
|
|||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
|
||||
|
||||
class Kandinsky5(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "kandinsky5",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 10.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.HunyuanVideo
|
||||
|
||||
memory_usage_factor = 1.25 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Kandinsky5(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class Kandinsky5Image(Kandinsky5):
|
||||
unet_config = {
|
||||
"image_model": "kandinsky5",
|
||||
"model_dim": 2560,
|
||||
"visual_embed_dim": 64,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Flux
|
||||
memory_usage_factor = 1.25 #TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Kandinsky5Image(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@
|
|||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
from . import model_base
|
||||
from . import utils
|
||||
from . import latent_formats
|
||||
|
|
@ -50,7 +49,8 @@ class BASE:
|
|||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
quant_config = None # quantization configuration for mixed precision
|
||||
scaled_fp8 = None
|
||||
layer_quant_config = None # Per-layer quantization configuration for mixed precision
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -118,7 +118,3 @@ class BASE:
|
|||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
def __getattr__(self, name):
|
||||
logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -154,8 +154,7 @@ class TAEHV(nn.Module):
|
|||
self._show_progress_bar = value
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if x.shape[1] % 4 != 0:
|
||||
# pad at end to multiple of 4
|
||||
|
|
@ -168,6 +167,5 @@ class TAEHV(nn.Module):
|
|||
def decode(self, x, **kwargs):
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_shuffle(x, self.patch_size)
|
||||
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
|
||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ from transformers import T5TokenizerFast
|
|||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
|
||||
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
|
||||
if t5xxl_quantization_metadata is not None:
|
||||
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||
if t5xxl_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = t5xxl_quantization_metadata
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
|
||||
|
||||
|
|
@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
|
|||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class CosmosTEModel_(CosmosT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
|
|
|||
|
|
@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module):
|
|||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
|
||||
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class FluxClipModel_(FluxClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return FluxClipModel_
|
||||
|
||||
|
|
@ -159,13 +159,15 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
|
|||
out = out.reshape(out.shape[0], out.shape[1], -1)
|
||||
return out, pooled, extra
|
||||
|
||||
def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
|
||||
def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
|
||||
class Flux2TEModel_(Flux2TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if pruned:
|
||||
model_options = model_options.copy()
|
||||
|
|
|
|||
|
|
@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
|
|||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class MochiTEModel_(MochiT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
|
|
|||
|
|
@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module):
|
|||
return self.llama.load_sd(sd)
|
||||
|
||||
|
||||
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
|
||||
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
|
||||
class HiDreamTEModel_(HiDreamTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return HiDreamTEModel_
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
|
|||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
|
|
@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel):
|
|||
else:
|
||||
return super().load_sd(sd)
|
||||
|
||||
def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
|
||||
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
|
||||
class QwenImageTEModel_(HunyuanImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast
|
|||
import torch
|
||||
import os
|
||||
import numbers
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def llama_detect(state_dict, prefix=""):
|
||||
out = {}
|
||||
|
|
@ -14,9 +14,12 @@ def llama_detect(state_dict, prefix=""):
|
|||
if t5_key in state_dict:
|
||||
out["dtype_llama"] = state_dict[t5_key].dtype
|
||||
|
||||
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
|
||||
if quant is not None:
|
||||
out["llama_quantization_metadata"] = quant
|
||||
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||
|
||||
if "_quantization_metadata" in state_dict:
|
||||
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
|
||||
|
||||
return out
|
||||
|
||||
|
|
@ -28,10 +31,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
|||
|
||||
class LLAMAModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
|
||||
textmodel_json_config = {}
|
||||
vocab_size = model_options.get("vocab_size", None)
|
||||
|
|
@ -158,11 +161,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
|||
return self.llama.load_sd(sd)
|
||||
|
||||
|
||||
def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
|
||||
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
|
||||
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return HunyuanVideoClipModel_
|
||||
|
|
|
|||
|
|
@ -1,219 +0,0 @@
|
|||
# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
|
||||
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
|
||||
# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
from comfy import sd1_clip
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
|
||||
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
|
||||
|
||||
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
|
||||
@dataclass
|
||||
class XLMRobertaConfig:
|
||||
vocab_size: int = 250002
|
||||
type_vocab_size: int = 1
|
||||
hidden_size: int = 1024
|
||||
num_hidden_layers: int = 24
|
||||
num_attention_heads: int = 16
|
||||
rotary_emb_base: float = 20000.0
|
||||
intermediate_size: int = 4096
|
||||
hidden_act: str = "gelu"
|
||||
hidden_dropout_prob: float = 0.1
|
||||
attention_probs_dropout_prob: float = 0.1
|
||||
layer_norm_eps: float = 1e-05
|
||||
bos_token_id: int = 0
|
||||
eos_token_id: int = 2
|
||||
pad_token_id: int = 1
|
||||
|
||||
class XLMRobertaEmbeddings(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
|
||||
self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, input_ids=None, embeddings=None):
|
||||
if input_ids is not None and embeddings is None:
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
|
||||
if embeddings is not None:
|
||||
token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
embeddings = embeddings + token_type_embeddings
|
||||
return embeddings
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, base, device=None):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
|
||||
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
||||
if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self._cos_cached = emb.cos().to(dtype)
|
||||
self._sin_cached = emb.sin().to(dtype)
|
||||
|
||||
def forward(self, q, k):
|
||||
batch, seqlen, heads, head_dim = q.shape
|
||||
self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
|
||||
|
||||
cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
|
||||
sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)
|
||||
|
||||
def rotate_half(x):
|
||||
size = x.shape[-1] // 2
|
||||
x1, x2 = x[..., :size], x[..., size:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
class MHA(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = embed_dim // config.num_attention_heads
|
||||
|
||||
self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
|
||||
self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
|
||||
self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None, optimized_attention=None):
|
||||
qkv = self.Wqkv(x)
|
||||
batch_size, seq_len, _ = qkv.shape
|
||||
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
|
||||
q, k, v = qkv.unbind(2)
|
||||
|
||||
q, k = self.rotary_emb(q, k)
|
||||
|
||||
# NHD -> HND
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
|
||||
return self.out_proj(out)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
|
||||
self.activation = F.gelu
|
||||
self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.activation(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
|
||||
self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, mask=None, optimized_attention=None):
|
||||
mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
|
||||
hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
|
||||
mlp_out = self.mlp(hidden_states)
|
||||
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class XLMRobertaEncoder(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None):
|
||||
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
|
||||
return hidden_states
|
||||
|
||||
class XLMRobertaModel_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
|
||||
self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
|
||||
x = self.emb_ln(x)
|
||||
x = self.emb_drop(x)
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
|
||||
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
|
||||
|
||||
sequence_output = self.encoder(x, attention_mask=mask)
|
||||
|
||||
# Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
|
||||
pooled_output = None
|
||||
if attention_mask is None:
|
||||
pooled_output = sequence_output.mean(dim=1)
|
||||
else:
|
||||
attention_mask = attention_mask.to(sequence_output.dtype)
|
||||
pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Intermediate output is not yet implemented, use None for placeholder
|
||||
return sequence_output, None, pooled_output
|
||||
|
||||
class XLMRobertaModel(nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.config = XLMRobertaConfig(**config_dict)
|
||||
self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
|
||||
self.num_layers = self.config.num_hidden_layers
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.model.embeddings.word_embeddings = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
|
||||
class JinaClip2TextModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
from comfy import sd1_clip
|
||||
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
|
||||
from .llama import Qwen25_7BVLI
|
||||
|
||||
|
||||
class Kandinsky5Tokenizer(QwenImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class Kandinsky5TEModel(QwenImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
|
||||
|
||||
return cond, l_pooled, extra
|
||||
|
||||
def set_clip_options(self, options):
|
||||
super().set_clip_options(options)
|
||||
self.clip_l.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
super().reset_clip_options()
|
||||
self.clip_l.reset_clip_options()
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
return self.clip_l.load_sd(sd)
|
||||
else:
|
||||
return super().load_sd(sd)
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Kandinsky5TEModel_(Kandinsky5TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Kandinsky5TEModel_
|
||||
|
|
@ -3,11 +3,13 @@ import torch.nn as nn
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
import math
|
||||
import logging
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
import comfy.model_management
|
||||
from . import qwen_vl
|
||||
|
||||
@dataclass
|
||||
|
|
@ -98,28 +100,6 @@ class Qwen3_4BConfig:
|
|||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Ovis25_2BConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 6144
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
vocab_size: int = 152064
|
||||
|
|
@ -175,7 +155,7 @@ class Gemma3_4B_Config:
|
|||
num_key_value_heads: int = 4
|
||||
max_position_embeddings: int = 131072
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta = [1000000.0, 10000.0]
|
||||
rope_theta = [10000.0, 1000000.0]
|
||||
transformer_type: str = "gemma3"
|
||||
head_dim = 256
|
||||
rms_norm_add = True
|
||||
|
|
@ -184,8 +164,8 @@ class Gemma3_4B_Config:
|
|||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
rope_scale = [8.0, 1.0]
|
||||
sliding_attention = [False, False, False, False, False, 1024]
|
||||
rope_scale = [1.0, 8.0]
|
||||
final_norm: bool = True
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
|
|
@ -368,7 +348,7 @@ class TransformerBlockGemma2(nn.Module):
|
|||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
if config.sliding_attention is not None:
|
||||
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
|
||||
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
|
||||
else:
|
||||
self.sliding_attention = False
|
||||
|
|
@ -385,12 +365,7 @@ class TransformerBlockGemma2(nn.Module):
|
|||
if self.transformer_type == 'gemma3':
|
||||
if self.sliding_attention:
|
||||
if x.shape[1] > self.sliding_attention:
|
||||
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
|
||||
sliding_mask.tril_(diagonal=-self.sliding_attention)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask + sliding_mask
|
||||
else:
|
||||
attention_mask = sliding_mask
|
||||
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
|
||||
freqs_cis = freqs_cis[1]
|
||||
else:
|
||||
freqs_cis = freqs_cis[0]
|
||||
|
|
@ -567,15 +542,6 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
|
|||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Ovis25_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Ovis25_2BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
|
|||
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
|
@ -33,11 +33,6 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
|
|||
|
||||
class Gemma3_4BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
|
|
@ -45,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
|
|||
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
|
||||
if model_type == "gemma2_2b":
|
||||
model = Gemma2_2BModel
|
||||
elif model_type == "gemma3_4b":
|
||||
|
|
@ -53,9 +48,9 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
|
|||
|
||||
class LuminaTEModel_(LuminaModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
|
||||
|
|
|
|||
|
|
@ -1,62 +0,0 @@
|
|||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
import comfy.text_encoders.lumina2
|
||||
|
||||
class NewBieTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
|
||||
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
raise NotImplementedError
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class NewBieTEModel(torch.nn.Module):
|
||||
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
|
||||
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
|
||||
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = {dtype, dtype_gemma}
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.gemma.set_clip_options(options)
|
||||
self.jina.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.gemma.reset_clip_options()
|
||||
self.jina.reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_gemma = token_weight_pairs["gemma"]
|
||||
token_weight_pairs_jina = token_weight_pairs["jina"]
|
||||
|
||||
gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
|
||||
jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)
|
||||
|
||||
return gemma_out, jina_pooled, gemma_extra
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.0.self_attn.q_norm.weight" in sd:
|
||||
return self.gemma.load_sd(sd)
|
||||
else:
|
||||
return self.jina.load_sd(sd)
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class NewBieTEModel_(NewBieTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return NewBieTEModel_
|
||||
|
|
@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
|
|||
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
class Omnigen2TEModel_(Omnigen2Model):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
|
|
|||
|
|
@ -1,66 +0,0 @@
|
|||
from transformers import Qwen2Tokenizer
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
import torch
|
||||
import numbers
|
||||
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class OvisTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
|
||||
self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
return tokens
|
||||
|
||||
class Ovis25_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
|
||||
|
||||
|
||||
class OvisTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs, template_end=-1):
|
||||
out, pooled = super().encode_token_weights(token_weight_pairs)
|
||||
tok_pairs = token_weight_pairs["qwen3_2b"][0]
|
||||
count_im_start = 0
|
||||
if template_end == -1:
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem):
|
||||
if isinstance(elem, numbers.Integral):
|
||||
if elem == 4004 and count_im_start < 1:
|
||||
template_end = i
|
||||
count_im_start += 1
|
||||
|
||||
if out.shape[1] > (template_end + 1):
|
||||
if tok_pairs[template_end + 1][0] == 25:
|
||||
template_end += 1
|
||||
|
||||
out = out[:, template_end:]
|
||||
return out, pooled, {}
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class OvisTEModel_(OvisTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return OvisTEModel_
|
||||
|
|
@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
|
|||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class PixArtTEModel_(PixArtT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
|
|
|||
|
|
@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
|
|||
return out, pooled, extra
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
class QwenImageTEModel_(QwenImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
|
|
|||
|
|
@ -6,15 +6,14 @@ import torch
|
|||
import os
|
||||
import comfy.model_management
|
||||
import logging
|
||||
import comfy.utils
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
|
||||
if t5xxl_quantization_metadata is not None:
|
||||
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||
if t5xxl_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = t5xxl_quantization_metadata
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
|
||||
model_options = {**model_options, "model_name": "t5xxl"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
|
@ -26,9 +25,9 @@ def t5_xxl_detect(state_dict, prefix=""):
|
|||
if t5_key in state_dict:
|
||||
out["dtype_t5"] = state_dict[t5_key].dtype
|
||||
|
||||
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
|
||||
if quant is not None:
|
||||
out["t5_quantization_metadata"] = quant
|
||||
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||
|
||||
return out
|
||||
|
||||
|
|
@ -157,11 +156,11 @@ class SD3ClipModel(torch.nn.Module):
|
|||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
|
||||
class SD3ClipModel_(SD3ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
|
||||
return SD3ClipModel_
|
||||
|
|
|
|||
|
|
@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel):
|
|||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
|
||||
|
||||
def te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class WanTEModel(WanT5Model):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = t5_quantization_metadata
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
|
|
|||
|
|
@ -34,9 +34,12 @@ class ZImageTEModel(sd1_clip.SD1ClipModel):
|
|||
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
|
||||
class ZImageTEModel_(ZImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ import itertools
|
|||
from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
|
|
@ -53,7 +52,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
|||
ALWAYS_SAFE_LOAD = True
|
||||
logging.info("Checkpoint files will always be loaded safely.")
|
||||
else:
|
||||
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
|
||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
|
|
@ -803,17 +802,12 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
|||
return None
|
||||
return f.read(length_of_header)
|
||||
|
||||
ATTR_UNSET={}
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1], ATTR_UNSET)
|
||||
if value is ATTR_UNSET:
|
||||
delattr(obj, attrs[-1])
|
||||
else:
|
||||
setattr(obj, attrs[-1], value)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
setattr(obj, attrs[-1], value)
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
|
|
@ -1200,70 +1194,3 @@ def unpack_latents(combined_latent, latent_shapes):
|
|||
else:
|
||||
output_tensors = combined_latent
|
||||
return output_tensors
|
||||
|
||||
def detect_layer_quantization(state_dict, prefix):
|
||||
for k in state_dict:
|
||||
if k.startswith(prefix) and k.endswith(".comfy_quant"):
|
||||
logging.info("Found quantization metadata version 1")
|
||||
return {"mixed_ops": True}
|
||||
return None
|
||||
|
||||
def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
quant_metadata = None
|
||||
if "_quantization_metadata" not in metadata:
|
||||
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
|
||||
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict[scaled_fp8_key]
|
||||
scaled_fp8_dtype = scaled_fp8_weight.dtype
|
||||
if scaled_fp8_dtype == torch.float32:
|
||||
scaled_fp8_dtype = torch.float8_e4m3fn
|
||||
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
full_precision_matrix_mult = True
|
||||
else:
|
||||
full_precision_matrix_mult = False
|
||||
|
||||
out_sd = {}
|
||||
layers = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if k == scaled_fp8_key:
|
||||
continue
|
||||
if not k.startswith(model_prefix):
|
||||
out_sd[k] = state_dict[k]
|
||||
continue
|
||||
k_out = k
|
||||
w = state_dict.pop(k)
|
||||
layer = None
|
||||
if k_out.endswith(".scale_weight"):
|
||||
layer = k_out[:-len(".scale_weight")]
|
||||
k_out = "{}.weight_scale".format(layer)
|
||||
|
||||
if layer is not None:
|
||||
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
|
||||
if full_precision_matrix_mult:
|
||||
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
||||
layers[layer] = layer_conf
|
||||
|
||||
if k_out.endswith(".scale_input"):
|
||||
layer = k_out[:-len(".scale_input")]
|
||||
k_out = "{}.input_scale".format(layer)
|
||||
if w.item() == 1.0:
|
||||
continue
|
||||
|
||||
out_sd[k_out] = w
|
||||
|
||||
state_dict = out_sd
|
||||
quant_metadata = {"layers": layers}
|
||||
else:
|
||||
quant_metadata = json.loads(metadata["_quantization_metadata"])
|
||||
|
||||
if quant_metadata is not None:
|
||||
layers = quant_metadata["layers"]
|
||||
for k, v in layers.items():
|
||||
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
return state_dict, metadata
|
||||
|
|
|
|||
|
|
@ -5,20 +5,19 @@ This module handles capability negotiation between frontend and backend,
|
|||
allowing graceful protocol evolution while maintaining backward compatibility.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
# Default server capabilities
|
||||
SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
||||
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
}
|
||||
|
||||
|
||||
def get_connection_feature(
|
||||
sockets_metadata: dict[str, dict[str, Any]],
|
||||
sockets_metadata: Dict[str, Dict[str, Any]],
|
||||
sid: str,
|
||||
feature_name: str,
|
||||
default: Any = False
|
||||
|
|
@ -42,7 +41,7 @@ def get_connection_feature(
|
|||
|
||||
|
||||
def supports_feature(
|
||||
sockets_metadata: dict[str, dict[str, Any]],
|
||||
sockets_metadata: Dict[str, Dict[str, Any]],
|
||||
sid: str,
|
||||
feature_name: str
|
||||
) -> bool:
|
||||
|
|
@ -60,7 +59,7 @@ def supports_feature(
|
|||
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
|
||||
|
||||
|
||||
def get_server_features() -> dict[str, Any]:
|
||||
def get_server_features() -> Dict[str, Any]:
|
||||
"""
|
||||
Get the server's feature flags.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import NamedTuple
|
||||
from typing import Type, List, NamedTuple
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from packaging import version as packaging_version
|
||||
|
||||
|
|
@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
|
|||
|
||||
class ComfyAPIWithVersion(NamedTuple):
|
||||
version: str
|
||||
api_class: type[ComfyAPIBase]
|
||||
api_class: Type[ComfyAPIBase]
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> packaging_version.Version:
|
||||
|
|
@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
|
|||
return packaging_version.parse(version_str)
|
||||
|
||||
|
||||
registered_versions: list[ComfyAPIWithVersion] = []
|
||||
registered_versions: List[ComfyAPIWithVersion] = []
|
||||
|
||||
|
||||
def register_versions(versions: list[ComfyAPIWithVersion]):
|
||||
def register_versions(versions: List[ComfyAPIWithVersion]):
|
||||
versions.sort(key=lambda x: parse_version(x.version))
|
||||
global registered_versions
|
||||
registered_versions = versions
|
||||
|
||||
|
||||
def get_all_versions() -> list[ComfyAPIWithVersion]:
|
||||
def get_all_versions() -> List[ComfyAPIWithVersion]:
|
||||
"""
|
||||
Returns a list of all registered ComfyAPI versions.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import os
|
|||
import textwrap
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional, get_origin, get_args, get_type_hints
|
||||
from typing import Optional, Type, get_origin, get_args, get_type_hints
|
||||
|
||||
|
||||
class TypeTracker:
|
||||
|
|
@ -193,7 +193,7 @@ class AsyncToSyncConverter:
|
|||
return result_container["result"]
|
||||
|
||||
@classmethod
|
||||
def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
|
||||
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
|
||||
"""
|
||||
Creates a new class with synchronous versions of all async methods.
|
||||
|
||||
|
|
@ -563,7 +563,7 @@ class AsyncToSyncConverter:
|
|||
|
||||
@classmethod
|
||||
def _generate_imports(
|
||||
cls, async_class: type, type_tracker: TypeTracker
|
||||
cls, async_class: Type, type_tracker: TypeTracker
|
||||
) -> list[str]:
|
||||
"""Generate import statements for the stub file."""
|
||||
imports = []
|
||||
|
|
@ -628,7 +628,7 @@ class AsyncToSyncConverter:
|
|||
return imports
|
||||
|
||||
@classmethod
|
||||
def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
|
||||
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
|
||||
"""Extract class attributes that are classes themselves."""
|
||||
class_attributes = []
|
||||
|
||||
|
|
@ -654,7 +654,7 @@ class AsyncToSyncConverter:
|
|||
def _generate_inner_class_stub(
|
||||
cls,
|
||||
name: str,
|
||||
attr: type,
|
||||
attr: Type,
|
||||
indent: str = " ",
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> list[str]:
|
||||
|
|
@ -782,7 +782,7 @@ class AsyncToSyncConverter:
|
|||
return processed
|
||||
|
||||
@classmethod
|
||||
def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
|
||||
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
|
||||
"""
|
||||
Generate a .pyi stub file for the sync class to help IDEs with type checking.
|
||||
"""
|
||||
|
|
@ -988,7 +988,7 @@ class AsyncToSyncConverter:
|
|||
logging.error(traceback.format_exc())
|
||||
|
||||
|
||||
def create_sync_class(async_class: type, thread_pool_size=10) -> type:
|
||||
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
|
||||
"""
|
||||
Creates a sync version of an async class
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TypeVar
|
||||
from typing import Type, TypeVar
|
||||
|
||||
class SingletonMetaclass(type):
|
||||
T = TypeVar("T", bound="SingletonMetaclass")
|
||||
|
|
@ -11,13 +11,13 @@ class SingletonMetaclass(type):
|
|||
)
|
||||
return cls._instances[cls]
|
||||
|
||||
def inject_instance(cls: type[T], instance: T) -> None:
|
||||
def inject_instance(cls: Type[T], instance: T) -> None:
|
||||
assert cls not in SingletonMetaclass._instances, (
|
||||
"Cannot inject instance after first instantiation"
|
||||
)
|
||||
SingletonMetaclass._instances[cls] = instance
|
||||
|
||||
def get_instance(cls: type[T], *args, **kwargs) -> T:
|
||||
def get_instance(cls: Type[T], *args, **kwargs) -> T:
|
||||
"""
|
||||
Gets the singleton instance of the class, creating it if it doesn't exist.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io as io
|
||||
from . import _ui as ui
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
|
|
@ -79,7 +80,7 @@ class ComfyExtension(ABC):
|
|||
async def on_load(self) -> None:
|
||||
"""
|
||||
Called when an extension is loaded.
|
||||
This should be used to initialize any global resources needed by the extension.
|
||||
This should be used to initialize any global resources neeeded by the extension.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -112,7 +113,7 @@ ComfyAPI = ComfyAPI_latest
|
|||
if TYPE_CHECKING:
|
||||
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
|
||||
|
||||
ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||
|
||||
# create new aliases for io and ui
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from typing import TypedDict, Optional
|
||||
from typing import TypedDict, List, Optional
|
||||
|
||||
ImageInput = torch.Tensor
|
||||
"""
|
||||
|
|
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
|
|||
Optional noise mask tensor in the same format as samples.
|
||||
"""
|
||||
|
||||
batch_index: Optional[list[int]]
|
||||
batch_index: Optional[List[int]]
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from fractions import Fraction
|
|||
from typing import Optional, Union, IO
|
||||
import io
|
||||
import av
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,14 +3,14 @@ from av.container import InputContainer
|
|||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from .._input import AudioInput, VideoInput
|
||||
from comfy_api.latest._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
|
|
|
|||
|
|
@ -4,8 +4,7 @@ import copy
|
|||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
|
||||
from typing_extensions import NotRequired, final
|
||||
|
|
@ -26,9 +25,11 @@ if TYPE_CHECKING:
|
|||
from comfy_api.input import VideoInput
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_api.latest._resources import Resources, ResourcesLocal
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG
|
||||
from ._util import MESH, VOXEL
|
||||
|
||||
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
|
||||
|
||||
class FolderType(str, Enum):
|
||||
input = "input"
|
||||
|
|
@ -75,6 +76,16 @@ class NumberDisplay(str, Enum):
|
|||
slider = "slider"
|
||||
|
||||
|
||||
class _StringIOType(str):
|
||||
def __ne__(self, value: object) -> bool:
|
||||
if self == "*" or value == "*":
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return True
|
||||
a = frozenset(self.split(","))
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
class _ComfyType(ABC):
|
||||
Type = Any
|
||||
io_type: str = None
|
||||
|
|
@ -114,7 +125,8 @@ def comfytype(io_type: str, **kwargs):
|
|||
new_cls.__module__ = cls.__module__
|
||||
new_cls.__doc__ = cls.__doc__
|
||||
# assign ComfyType attributes, if needed
|
||||
new_cls.io_type = io_type
|
||||
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details)
|
||||
new_cls.io_type = _StringIOType(io_type)
|
||||
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
||||
new_cls.Input.Parent = new_cls
|
||||
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
||||
|
|
@ -138,9 +150,6 @@ class _IO_V3:
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
def validate(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def io_type(self):
|
||||
return self.Parent.io_type
|
||||
|
|
@ -153,7 +162,7 @@ class Input(_IO_V3):
|
|||
'''
|
||||
Base class for a V3 Input.
|
||||
'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
super().__init__()
|
||||
self.id = id
|
||||
self.display_name = display_name
|
||||
|
|
@ -161,7 +170,6 @@ class Input(_IO_V3):
|
|||
self.tooltip = tooltip
|
||||
self.lazy = lazy
|
||||
self.extra_dict = extra_dict if extra_dict is not None else {}
|
||||
self.rawLink = raw_link
|
||||
|
||||
def as_dict(self):
|
||||
return prune_dict({
|
||||
|
|
@ -169,14 +177,10 @@ class Input(_IO_V3):
|
|||
"optional": self.optional,
|
||||
"tooltip": self.tooltip,
|
||||
"lazy": self.lazy,
|
||||
"rawLink": self.rawLink,
|
||||
}) | prune_dict(self.extra_dict)
|
||||
|
||||
def get_io_type(self):
|
||||
return self.io_type
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self]
|
||||
return _StringIOType(self.io_type)
|
||||
|
||||
class WidgetInput(Input):
|
||||
'''
|
||||
|
|
@ -184,8 +188,8 @@ class WidgetInput(Input):
|
|||
'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: Any=None,
|
||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.default = default
|
||||
self.socketless = socketless
|
||||
self.widget_type = widget_type
|
||||
|
|
@ -207,14 +211,13 @@ class Output(_IO_V3):
|
|||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
self.id = id
|
||||
self.display_name = display_name if display_name else id
|
||||
self.display_name = display_name
|
||||
self.tooltip = tooltip
|
||||
self.is_output_list = is_output_list
|
||||
|
||||
def as_dict(self):
|
||||
display_name = self.display_name if self.display_name else self.id
|
||||
return prune_dict({
|
||||
"display_name": display_name,
|
||||
"display_name": self.display_name,
|
||||
"tooltip": self.tooltip,
|
||||
"is_output_list": self.is_output_list,
|
||||
})
|
||||
|
|
@ -242,8 +245,8 @@ class Boolean(ComfyTypeIO):
|
|||
'''Boolean input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: bool=None, label_on: str=None, label_off: str=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
self.label_on = label_on
|
||||
self.label_off = label_off
|
||||
self.default: bool
|
||||
|
|
@ -262,8 +265,8 @@ class Int(ComfyTypeIO):
|
|||
'''Integer input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.step = step
|
||||
|
|
@ -288,8 +291,8 @@ class Float(ComfyTypeIO):
|
|||
'''Float input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.step = step
|
||||
|
|
@ -314,8 +317,8 @@ class String(ComfyTypeIO):
|
|||
'''String input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
self.multiline = multiline
|
||||
self.placeholder = placeholder
|
||||
self.dynamic_prompts = dynamic_prompts
|
||||
|
|
@ -348,14 +351,12 @@ class Combo(ComfyTypeIO):
|
|||
image_folder: FolderType=None,
|
||||
remote: RemoteOptions=None,
|
||||
socketless: bool=None,
|
||||
extra_dict=None,
|
||||
raw_link: bool=None,
|
||||
):
|
||||
if isinstance(options, type) and issubclass(options, Enum):
|
||||
options = [v.value for v in options]
|
||||
if isinstance(default, Enum):
|
||||
default = default.value
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
||||
self.multiselect = False
|
||||
self.options = options
|
||||
self.control_after_generate = control_after_generate
|
||||
|
|
@ -379,6 +380,10 @@ class Combo(ComfyTypeIO):
|
|||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
self.options = options if options is not None else []
|
||||
|
||||
@property
|
||||
def io_type(self):
|
||||
return self.options
|
||||
|
||||
@comfytype(io_type="COMBO")
|
||||
class MultiCombo(ComfyTypeI):
|
||||
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||
|
|
@ -387,8 +392,8 @@ class MultiCombo(ComfyTypeI):
|
|||
class Input(Combo.Input):
|
||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||
socketless: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
|
||||
socketless: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless)
|
||||
self.multiselect = True
|
||||
self.placeholder = placeholder
|
||||
self.chip = chip
|
||||
|
|
@ -421,9 +426,9 @@ class Webcam(ComfyTypeIO):
|
|||
Type = str
|
||||
def __init__(
|
||||
self, id: str, display_name: str=None, optional=False,
|
||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
|
||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None
|
||||
):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
||||
|
||||
|
||||
@comfytype(io_type="MASK")
|
||||
|
|
@ -556,8 +561,6 @@ class Conditioning(ComfyTypeIO):
|
|||
'''Used by WAN Camera.'''
|
||||
time_dim_concat: NotRequired[torch.Tensor]
|
||||
'''Used by WAN Phantom Subject.'''
|
||||
time_dim_replace: NotRequired[torch.Tensor]
|
||||
'''Used by Kandinsky5 I2V.'''
|
||||
|
||||
CondList = list[tuple[torch.Tensor, PooledDict]]
|
||||
Type = CondList
|
||||
|
|
@ -644,7 +647,7 @@ class Video(ComfyTypeIO):
|
|||
|
||||
@comfytype(io_type="SVG")
|
||||
class SVG(ComfyTypeIO):
|
||||
Type = _SVG
|
||||
Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3
|
||||
|
||||
@comfytype(io_type="LORA_MODEL")
|
||||
class LoraModel(ComfyTypeIO):
|
||||
|
|
@ -762,13 +765,6 @@ class AudioEncoder(ComfyTypeIO):
|
|||
class AudioEncoderOutput(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="TRACKS")
|
||||
class Tracks(ComfyTypeIO):
|
||||
class TrackDict(TypedDict):
|
||||
track_path: torch.Tensor
|
||||
track_visibility: torch.Tensor
|
||||
Type = TrackDict
|
||||
|
||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||
class MultiType:
|
||||
Type = Any
|
||||
|
|
@ -776,7 +772,7 @@ class MultiType:
|
|||
'''
|
||||
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
|
||||
'''
|
||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
# if id is an Input, then use that Input with overridden values
|
||||
self.input_override = None
|
||||
if isinstance(id, Input):
|
||||
|
|
@ -789,7 +785,7 @@ class MultiType:
|
|||
# if is a widget input, make sure widget_type is set appropriately
|
||||
if isinstance(self.input_override, WidgetInput):
|
||||
self.input_override.widget_type = self.input_override.get_io_type()
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self._io_types = types
|
||||
|
||||
@property
|
||||
|
|
@ -818,326 +814,115 @@ class MultiType:
|
|||
else:
|
||||
return super().as_dict()
|
||||
|
||||
@comfytype(io_type="COMFY_MATCHTYPE_V3")
|
||||
class MatchType(ComfyTypeIO):
|
||||
class Template:
|
||||
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
|
||||
self.template_id = template_id
|
||||
# account for syntactic sugar
|
||||
if not isinstance(allowed_types, Iterable):
|
||||
allowed_types = [allowed_types]
|
||||
for t in allowed_types:
|
||||
if not isinstance(t, type):
|
||||
if not isinstance(t, _ComfyType):
|
||||
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
|
||||
else:
|
||||
if not issubclass(t, _ComfyType):
|
||||
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
|
||||
self.allowed_types = allowed_types
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"template_id": self.template_id,
|
||||
"allowed_types": ",".join([t.io_type for t in self.allowed_types]),
|
||||
}
|
||||
|
||||
class Input(Input):
|
||||
def __init__(self, id: str, template: MatchType.Template,
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
self.template = template
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
class Output(Output):
|
||||
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
if not id and not display_name:
|
||||
display_name = "MATCHTYPE"
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
self.template = template
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
class DynamicInput(Input, ABC):
|
||||
'''
|
||||
Abstract class for dynamic input registration.
|
||||
'''
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
...
|
||||
|
||||
class DynamicOutput(Output, ABC):
|
||||
'''
|
||||
Abstract class for dynamic output registration.
|
||||
'''
|
||||
pass
|
||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
|
||||
@abstractmethod
|
||||
def get_dynamic(self) -> list[Output]:
|
||||
...
|
||||
|
||||
def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]:
|
||||
if prefix_list is None:
|
||||
prefix_list = []
|
||||
if id is not None:
|
||||
prefix_list = prefix_list + [id]
|
||||
return prefix_list
|
||||
|
||||
def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str:
|
||||
assert not (prefix_list is None and id is None)
|
||||
if prefix_list is None:
|
||||
return id
|
||||
elif id is not None:
|
||||
prefix_list = prefix_list + [id]
|
||||
return ".".join(prefix_list)
|
||||
|
||||
@comfytype(io_type="COMFY_AUTOGROW_V3")
|
||||
class Autogrow(ComfyTypeI):
|
||||
Type = dict[str, Any]
|
||||
_MaxNames = 100 # NOTE: max 100 names for sanity
|
||||
|
||||
class _AutogrowTemplate:
|
||||
def __init__(self, input: Input):
|
||||
# dynamic inputs are not allowed as the template input
|
||||
assert(not isinstance(input, DynamicInput))
|
||||
self.input = copy.copy(input)
|
||||
if isinstance(self.input, WidgetInput):
|
||||
self.input.force_input = True
|
||||
self.names: list[str] = []
|
||||
self.cached_inputs = {}
|
||||
|
||||
def _create_input(self, input: Input, name: str):
|
||||
new_input = copy.copy(self.input)
|
||||
new_input.id = name
|
||||
return new_input
|
||||
|
||||
def _create_cached_inputs(self):
|
||||
for name in self.names:
|
||||
self.cached_inputs[name] = self._create_input(self.input, name)
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return list(self.cached_inputs.values())
|
||||
|
||||
def as_dict(self):
|
||||
return prune_dict({
|
||||
"input": create_input_dict_v1([self.input]),
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
self.input.validate()
|
||||
|
||||
class TemplatePrefix(_AutogrowTemplate):
|
||||
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
|
||||
super().__init__(input)
|
||||
self.prefix = prefix
|
||||
assert(min >= 0)
|
||||
assert(max >= 1)
|
||||
assert(max <= Autogrow._MaxNames)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.names = [f"{self.prefix}{i}" for i in range(self.max)]
|
||||
self._create_cached_inputs()
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"prefix": self.prefix,
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
})
|
||||
|
||||
class TemplateNames(_AutogrowTemplate):
|
||||
def __init__(self, input: Input, names: list[str], min: int=1):
|
||||
super().__init__(input)
|
||||
self.names = names[:Autogrow._MaxNames]
|
||||
assert(min >= 0)
|
||||
self.min = min
|
||||
self._create_cached_inputs()
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"names": self.names,
|
||||
"min": self.min,
|
||||
})
|
||||
|
||||
class AutogrowDynamic(ComfyTypeI):
|
||||
Type = list[Any]
|
||||
class Input(DynamicInput):
|
||||
def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
|
||||
def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.template_input = template_input
|
||||
if min is not None:
|
||||
assert(min >= 1)
|
||||
if max is not None:
|
||||
assert(max >= 1)
|
||||
self.min = min
|
||||
self.max = max
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
curr_count = 1
|
||||
new_inputs = []
|
||||
for i in range(self.min):
|
||||
new_input = copy.copy(self.template_input)
|
||||
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
|
||||
if new_input.display_name is not None:
|
||||
new_input.display_name = f"{new_input.display_name}{curr_count}"
|
||||
new_input.optional = self.optional or new_input.optional
|
||||
if isinstance(self.template_input, WidgetInput):
|
||||
new_input.force_input = True
|
||||
new_inputs.append(new_input)
|
||||
curr_count += 1
|
||||
# pretend to expand up to max
|
||||
for i in range(curr_count-1, self.max):
|
||||
new_input = copy.copy(self.template_input)
|
||||
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
|
||||
if new_input.display_name is not None:
|
||||
new_input.display_name = f"{new_input.display_name}{curr_count}"
|
||||
new_input.optional = True
|
||||
if isinstance(self.template_input, WidgetInput):
|
||||
new_input.force_input = True
|
||||
new_inputs.append(new_input)
|
||||
curr_count += 1
|
||||
return new_inputs
|
||||
|
||||
@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
|
||||
class ComboDynamic(ComfyTypeI):
|
||||
class Input(DynamicInput):
|
||||
def __init__(self, id: str):
|
||||
pass
|
||||
|
||||
@comfytype(io_type="COMFY_MATCHTYPE_V3")
|
||||
class MatchType(ComfyTypeIO):
|
||||
class Template:
|
||||
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
|
||||
self.template_id = template_id
|
||||
self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"template_id": self.template_id,
|
||||
"allowed_types": "".join(t.io_type for t in self.allowed_types),
|
||||
}
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(self, id: str, template: MatchType.Template,
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.template = template
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return [self]
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + self.template.get_all()
|
||||
class Output(DynamicOutput):
|
||||
def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
self.template = template
|
||||
|
||||
def validate(self):
|
||||
self.template.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||
# NOTE: purposely do not include self in out_dict; instead use only the template inputs
|
||||
# need to figure out names based on template type
|
||||
is_names = ("names" in value[1]["template"])
|
||||
is_prefix = ("prefix" in value[1]["template"])
|
||||
input = value[1]["template"]["input"]
|
||||
if is_names:
|
||||
min = value[1]["template"]["min"]
|
||||
names = value[1]["template"]["names"]
|
||||
max = len(names)
|
||||
elif is_prefix:
|
||||
prefix = value[1]["template"]["prefix"]
|
||||
min = value[1]["template"]["min"]
|
||||
max = value[1]["template"]["max"]
|
||||
names = [f"{prefix}{i}" for i in range(max)]
|
||||
# need to create a new input based on the contents of input
|
||||
template_input = None
|
||||
for _, dict_input in input.items():
|
||||
# for now, get just the first value from dict_input
|
||||
template_input = list(dict_input.values())[0]
|
||||
new_dict = {}
|
||||
for i, name in enumerate(names):
|
||||
expected_id = finalize_prefix(curr_prefix, name)
|
||||
if expected_id in live_inputs:
|
||||
# required
|
||||
if i < min:
|
||||
type_dict = new_dict.setdefault("required", {})
|
||||
# optional
|
||||
else:
|
||||
type_dict = new_dict.setdefault("optional", {})
|
||||
type_dict[name] = template_input
|
||||
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
||||
class DynamicCombo(ComfyTypeI):
|
||||
Type = dict[str, Any]
|
||||
|
||||
class Option:
|
||||
def __init__(self, key: str, inputs: list[Input]):
|
||||
self.key = key
|
||||
self.inputs = inputs
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"key": self.key,
|
||||
"inputs": create_input_dict_v1(self.inputs),
|
||||
}
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(self, id: str, options: list[DynamicCombo.Option],
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.options = options
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + [input for option in self.options for input in option.inputs]
|
||||
def get_dynamic(self) -> list[Output]:
|
||||
return [self]
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"options": [o.as_dict() for o in self.options],
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
# make sure all nested inputs are validated
|
||||
for option in self.options:
|
||||
for input in option.inputs:
|
||||
input.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||
finalized_id = finalize_prefix(curr_prefix)
|
||||
if finalized_id in live_inputs:
|
||||
key = live_inputs[finalized_id]
|
||||
selected_option = None
|
||||
# get options from dict
|
||||
options: list[dict[str, str | dict[str, Any]]] = value[1]["options"]
|
||||
for option in options:
|
||||
if option["key"] == key:
|
||||
selected_option = option
|
||||
break
|
||||
if selected_option is not None:
|
||||
parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix)
|
||||
# add self to inputs
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
|
||||
class DynamicSlot(ComfyTypeI):
|
||||
Type = dict[str, Any]
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(self, slot: Input, inputs: list[Input],
|
||||
display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
assert(not isinstance(slot, DynamicInput))
|
||||
self.slot = copy.copy(slot)
|
||||
self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
|
||||
optional = True
|
||||
self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
|
||||
self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
|
||||
self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
|
||||
super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
|
||||
self.inputs = inputs
|
||||
self.force_input = None
|
||||
# force widget inputs to have no widgets, otherwise this would be awkward
|
||||
if isinstance(self.slot, WidgetInput):
|
||||
self.force_input = True
|
||||
self.slot.force_input = True
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self.slot] + self.inputs
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"slotType": str(self.slot.get_io_type()),
|
||||
"inputs": create_input_dict_v1(self.inputs),
|
||||
"forceInput": self.force_input,
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
self.slot.validate()
|
||||
for input in self.inputs:
|
||||
input.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||
finalized_id = finalize_prefix(curr_prefix)
|
||||
if finalized_id in live_inputs:
|
||||
inputs = value[1]["inputs"]
|
||||
parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix)
|
||||
# add self to inputs
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
|
||||
def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]:
|
||||
return DYNAMIC_INPUT_LOOKUP[io_type]
|
||||
|
||||
def setup_dynamic_input_funcs():
|
||||
# Autogrow.Input
|
||||
register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic)
|
||||
# DynamicCombo.Input
|
||||
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||
# DynamicSlot.Input
|
||||
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||
|
||||
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||
setup_dynamic_input_funcs()
|
||||
|
||||
class V3Data(TypedDict):
|
||||
hidden_inputs: dict[str, Any]
|
||||
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
|
||||
dynamic_paths: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
class HiddenHolder:
|
||||
def __init__(self, unique_id: str, prompt: Any,
|
||||
|
|
@ -1173,10 +958,6 @@ class HiddenHolder:
|
|||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder:
|
||||
return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None)
|
||||
|
||||
class Hidden(str, Enum):
|
||||
'''
|
||||
Enumerator for requesting hidden variables in nodes.
|
||||
|
|
@ -1203,7 +984,6 @@ class NodeInfoV1:
|
|||
output_is_list: list[bool]=None
|
||||
output_name: list[str]=None
|
||||
output_tooltips: list[str]=None
|
||||
output_matchtypes: list[str]=None
|
||||
name: str=None
|
||||
display_name: str=None
|
||||
description: str=None
|
||||
|
|
@ -1239,9 +1019,9 @@ class Schema:
|
|||
"""Display name of node."""
|
||||
category: str = "sd"
|
||||
"""The category of the node, as per the "Add Node" menu."""
|
||||
inputs: list[Input] = field(default_factory=list)
|
||||
outputs: list[Output] = field(default_factory=list)
|
||||
hidden: list[Hidden] = field(default_factory=list)
|
||||
inputs: list[Input]=None
|
||||
outputs: list[Output]=None
|
||||
hidden: list[Hidden]=None
|
||||
description: str=""
|
||||
"""Node description, shown as a tooltip when hovering over the node."""
|
||||
is_input_list: bool = False
|
||||
|
|
@ -1281,57 +1061,60 @@ class Schema:
|
|||
'''Validate the schema:
|
||||
- verify ids on inputs and outputs are unique - both internally and in relation to each other
|
||||
'''
|
||||
nested_inputs: list[Input] = []
|
||||
for input in self.inputs:
|
||||
if not isinstance(input, DynamicInput):
|
||||
nested_inputs.extend(input.get_all())
|
||||
input_ids = [i.id for i in nested_inputs]
|
||||
output_ids = [o.id for o in self.outputs]
|
||||
input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
|
||||
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
|
||||
input_set = set(input_ids)
|
||||
output_set = set(output_ids)
|
||||
issues: list[str] = []
|
||||
issues = []
|
||||
# verify ids are unique per list
|
||||
if len(input_set) != len(input_ids):
|
||||
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
|
||||
if len(output_set) != len(output_ids):
|
||||
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
|
||||
# verify ids are unique between lists
|
||||
intersection = input_set & output_set
|
||||
if len(intersection) > 0:
|
||||
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
|
||||
if len(issues) > 0:
|
||||
raise ValueError("\n".join(issues))
|
||||
# validate inputs and outputs
|
||||
for input in self.inputs:
|
||||
input.validate()
|
||||
for output in self.outputs:
|
||||
output.validate()
|
||||
|
||||
def finalize(self):
|
||||
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
||||
# ensure inputs, outputs, and hidden are lists
|
||||
if self.inputs is None:
|
||||
self.inputs = []
|
||||
if self.outputs is None:
|
||||
self.outputs = []
|
||||
if self.hidden is None:
|
||||
self.hidden = []
|
||||
# if is an api_node, will need key-related hidden
|
||||
if self.is_api_node:
|
||||
if self.hidden is None:
|
||||
self.hidden = []
|
||||
if Hidden.auth_token_comfy_org not in self.hidden:
|
||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||
if Hidden.api_key_comfy_org not in self.hidden:
|
||||
self.hidden.append(Hidden.api_key_comfy_org)
|
||||
# if is an output_node, will need prompt and extra_pnginfo
|
||||
if self.is_output_node:
|
||||
if self.hidden is None:
|
||||
self.hidden = []
|
||||
if Hidden.prompt not in self.hidden:
|
||||
self.hidden.append(Hidden.prompt)
|
||||
if Hidden.extra_pnginfo not in self.hidden:
|
||||
self.hidden.append(Hidden.extra_pnginfo)
|
||||
# give outputs without ids default ids
|
||||
for i, output in enumerate(self.outputs):
|
||||
if output.id is None:
|
||||
output.id = f"_{i}_{output.io_type}_"
|
||||
if self.outputs is not None:
|
||||
for i, output in enumerate(self.outputs):
|
||||
if output.id is None:
|
||||
output.id = f"_{i}_{output.io_type}_"
|
||||
|
||||
def get_v1_info(self, cls) -> NodeInfoV1:
|
||||
# get V1 inputs
|
||||
input = create_input_dict_v1(self.inputs)
|
||||
input = {
|
||||
"required": {}
|
||||
}
|
||||
if self.inputs:
|
||||
for i in self.inputs:
|
||||
if isinstance(i, DynamicInput):
|
||||
dynamic_inputs = i.get_dynamic()
|
||||
for d in dynamic_inputs:
|
||||
add_to_dict_v1(d, input)
|
||||
else:
|
||||
add_to_dict_v1(i, input)
|
||||
if self.hidden:
|
||||
for hidden in self.hidden:
|
||||
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||
|
|
@ -1340,24 +1123,12 @@ class Schema:
|
|||
output_is_list = []
|
||||
output_name = []
|
||||
output_tooltips = []
|
||||
output_matchtypes = []
|
||||
any_matchtypes = False
|
||||
if self.outputs:
|
||||
for o in self.outputs:
|
||||
output.append(o.io_type)
|
||||
output_is_list.append(o.is_output_list)
|
||||
output_name.append(o.display_name if o.display_name else o.io_type)
|
||||
output_tooltips.append(o.tooltip if o.tooltip else None)
|
||||
# special handling for MatchType
|
||||
if isinstance(o, MatchType.Output):
|
||||
output_matchtypes.append(o.template.template_id)
|
||||
any_matchtypes = True
|
||||
else:
|
||||
output_matchtypes.append(None)
|
||||
|
||||
# clear out lists that are all None
|
||||
if not any_matchtypes:
|
||||
output_matchtypes = None
|
||||
|
||||
info = NodeInfoV1(
|
||||
input=input,
|
||||
|
|
@ -1366,7 +1137,6 @@ class Schema:
|
|||
output_is_list=output_is_list,
|
||||
output_name=output_name,
|
||||
output_tooltips=output_tooltips,
|
||||
output_matchtypes=output_matchtypes,
|
||||
name=self.node_id,
|
||||
display_name=self.display_name,
|
||||
category=self.category,
|
||||
|
|
@ -1411,84 +1181,17 @@ class Schema:
|
|||
)
|
||||
return info
|
||||
|
||||
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
|
||||
out_dict = {
|
||||
"required": {},
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
hidden = d.pop("hidden", None)
|
||||
parse_class_inputs(out_dict, live_inputs, d)
|
||||
if hidden is not None and include_hidden:
|
||||
out_dict["hidden"] = hidden
|
||||
v3_data = {}
|
||||
dynamic_paths = out_dict.pop("dynamic_paths", None)
|
||||
if dynamic_paths is not None:
|
||||
v3_data["dynamic_paths"] = dynamic_paths
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
for input_type, inner_d in curr_dict.items():
|
||||
for id, value in inner_d.items():
|
||||
io_type = value[0]
|
||||
if io_type in DYNAMIC_INPUT_LOOKUP:
|
||||
# dynamic inputs need to be handled with lookup functions
|
||||
dynamic_input_func = get_dynamic_input_func(io_type)
|
||||
new_prefix = handle_prefix(curr_prefix, id)
|
||||
dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix)
|
||||
else:
|
||||
# non-dynamic inputs get directly transferred
|
||||
finalized_id = finalize_prefix(curr_prefix, id)
|
||||
out_dict[input_type][finalized_id] = value
|
||||
if curr_prefix:
|
||||
out_dict["dynamic_paths"][finalized_id] = finalized_id
|
||||
|
||||
def create_input_dict_v1(inputs: list[Input]) -> dict:
|
||||
input = {
|
||||
"required": {}
|
||||
}
|
||||
for i in inputs:
|
||||
add_to_dict_v1(i, input)
|
||||
return input
|
||||
|
||||
def add_to_dict_v1(i: Input, d: dict):
|
||||
def add_to_dict_v1(i: Input, input: dict):
|
||||
key = "optional" if i.optional else "required"
|
||||
as_dict = i.as_dict()
|
||||
# for v1, we don't want to include the optional key
|
||||
as_dict.pop("optional", None)
|
||||
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||
input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||
|
||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||
|
||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
paths = v3_data.get("dynamic_paths", None)
|
||||
if paths is None:
|
||||
return values
|
||||
values = values.copy()
|
||||
result = {}
|
||||
|
||||
create_tuple = v3_data.get("create_dynamic_tuple", False)
|
||||
|
||||
for key, path in paths.items():
|
||||
parts = path.split(".")
|
||||
current = result
|
||||
|
||||
for i, p in enumerate(parts):
|
||||
is_last = (i == len(parts) - 1)
|
||||
|
||||
if is_last:
|
||||
value = values.pop(key, None)
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
else:
|
||||
current = current.setdefault(p, {})
|
||||
|
||||
values.update(result)
|
||||
return values
|
||||
|
||||
|
||||
class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
|
|
@ -1498,6 +1201,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||
SCHEMA = None
|
||||
|
||||
# filled in during execution
|
||||
resources: Resources = None
|
||||
hidden: HiddenHolder = None
|
||||
|
||||
@classmethod
|
||||
|
|
@ -1544,6 +1248,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||
return [name for name in kwargs if kwargs[name] is None]
|
||||
|
||||
def __init__(self):
|
||||
self.local_resources: ResourcesLocal = None
|
||||
self.__class__.VALIDATE_CLASS()
|
||||
|
||||
@classmethod
|
||||
|
|
@ -1606,12 +1311,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||
|
||||
@final
|
||||
@classmethod
|
||||
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]:
|
||||
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
|
||||
"""Creates clone of real node class to prevent monkey-patching."""
|
||||
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
||||
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
||||
# set hidden
|
||||
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
|
||||
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
|
||||
return type_clone
|
||||
|
||||
@final
|
||||
|
|
@ -1728,10 +1433,15 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
|
||||
schema = cls.FINALIZE_SCHEMA()
|
||||
info = schema.get_v1_info(cls)
|
||||
return info.input
|
||||
input = info.input
|
||||
if not include_hidden:
|
||||
input.pop("hidden", None)
|
||||
if return_schema:
|
||||
return input, schema
|
||||
return input
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
|
|
@ -1803,7 +1513,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
|
|||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, **kwargs) -> bool | str:
|
||||
def validate_inputs(cls, **kwargs) -> bool:
|
||||
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -1850,7 +1560,7 @@ class NodeOutput(_NodeOutputInternal):
|
|||
return self.args if len(self.args) > 0 else None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> NodeOutput:
|
||||
def from_dict(cls, data: dict[str, Any]) -> "NodeOutput":
|
||||
args = ()
|
||||
ui = None
|
||||
expand = None
|
||||
|
|
@ -1863,7 +1573,7 @@ class NodeOutput(_NodeOutputInternal):
|
|||
ui = data["ui"]
|
||||
if "expand" in data:
|
||||
expand = data["expand"]
|
||||
return cls(*args, ui=ui, expand=expand)
|
||||
return cls(args=args, ui=ui, expand=expand)
|
||||
|
||||
def __getitem__(self, index) -> Any:
|
||||
return self.args[index]
|
||||
|
|
@ -1918,7 +1628,6 @@ __all__ = [
|
|||
"StyleModel",
|
||||
"Gligen",
|
||||
"UpscaleModel",
|
||||
"LatentUpscaleModel",
|
||||
"Audio",
|
||||
"Video",
|
||||
"SVG",
|
||||
|
|
@ -1942,11 +1651,6 @@ __all__ = [
|
|||
"SEGS",
|
||||
"AnyType",
|
||||
"MultiType",
|
||||
"Tracks",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
"DynamicCombo",
|
||||
"Autogrow",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
|
|
@ -1957,5 +1661,4 @@ __all__ = [
|
|||
"NodeOutput",
|
||||
"add_to_dict_v1",
|
||||
"add_to_dict_v3",
|
||||
"V3Data",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
from ._io import * # noqa: F403
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
class ResourceKey(ABC):
|
||||
Type = Any
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
class TorchDictFolderFilename(ResourceKey):
|
||||
'''Key for requesting a torch file via file_name from a folder category.'''
|
||||
Type = dict[str, torch.Tensor]
|
||||
def __init__(self, folder_name: str, file_name: str):
|
||||
self.folder_name = folder_name
|
||||
self.file_name = file_name
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.folder_name, self.file_name))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TorchDictFolderFilename):
|
||||
return False
|
||||
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.folder_name} -> {self.file_name}"
|
||||
|
||||
class Resources(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
pass
|
||||
|
||||
class ResourcesLocal(Resources):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.local_resources: dict[ResourceKey, Any] = {}
|
||||
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
cached = self.local_resources.get(key, None)
|
||||
if cached is not None:
|
||||
logging.info(f"Using cached resource '{key}'")
|
||||
return cached
|
||||
logging.info(f"Loading resource '{key}'")
|
||||
to_return = None
|
||||
if isinstance(key, TorchDictFolderFilename):
|
||||
if default is ...:
|
||||
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
||||
else:
|
||||
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
||||
if full_path is not None:
|
||||
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
||||
|
||||
if to_return is not None:
|
||||
self.local_resources[key] = to_return
|
||||
return to_return
|
||||
if default is not ...:
|
||||
return default
|
||||
raise Exception(f"Unsupported resource key type: {type(key)}")
|
||||
|
||||
|
||||
class _RESOURCES:
|
||||
ResourceKey = ResourceKey
|
||||
TorchDictFolderFilename = TorchDictFolderFilename
|
||||
Resources = Resources
|
||||
ResourcesLocal = ResourcesLocal
|
||||
|
|
@ -3,8 +3,8 @@ from __future__ import annotations
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Type
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
|
@ -21,7 +21,7 @@ import folder_paths
|
|||
|
||||
# used for image preview
|
||||
from comfy.cli_args import args
|
||||
from ._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
|
||||
|
||||
class SavedResult(dict):
|
||||
|
|
@ -82,7 +82,7 @@ class ImageSaveHelper:
|
|||
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
||||
|
||||
@staticmethod
|
||||
def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
|
||||
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
|
|
@ -95,7 +95,7 @@ class ImageSaveHelper:
|
|||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
|
||||
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
|
|
@ -120,7 +120,7 @@ class ImageSaveHelper:
|
|||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
|
||||
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
|
||||
"""Creates EXIF metadata bytes for WebP images."""
|
||||
exif_data = pil_image.getexif()
|
||||
if args.disable_metadata or cls is None or cls.hidden is None:
|
||||
|
|
@ -136,7 +136,7 @@ class ImageSaveHelper:
|
|||
|
||||
@staticmethod
|
||||
def save_images(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
|
||||
) -> list[SavedResult]:
|
||||
"""Saves a batch of images as individual PNG files."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
|
|
@ -154,7 +154,7 @@ class ImageSaveHelper:
|
|||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||
"""Saves a batch of images and returns a UI object for the node output."""
|
||||
return SavedImages(
|
||||
ImageSaveHelper.save_images(
|
||||
|
|
@ -168,7 +168,7 @@ class ImageSaveHelper:
|
|||
|
||||
@staticmethod
|
||||
def save_animated_png(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated PNG."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
|
|
@ -190,7 +190,7 @@ class ImageSaveHelper:
|
|||
|
||||
@staticmethod
|
||||
def get_save_animated_png_ui(
|
||||
images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
|
||||
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedImages:
|
||||
"""Saves an animated PNG and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_png(
|
||||
|
|
@ -208,7 +208,7 @@ class ImageSaveHelper:
|
|||
images,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: type[ComfyNode] | None,
|
||||
cls: Type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
|
|
@ -237,7 +237,7 @@ class ImageSaveHelper:
|
|||
def get_save_animated_webp_ui(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
cls: type[ComfyNode] | None,
|
||||
cls: Type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
|
|
@ -266,7 +266,7 @@ class AudioSaveHelper:
|
|||
audio: dict,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: type[ComfyNode] | None,
|
||||
cls: Type[ComfyNode] | None,
|
||||
format: str = "flac",
|
||||
quality: str = "128k",
|
||||
) -> list[SavedResult]:
|
||||
|
|
@ -318,10 +318,9 @@ class AudioSaveHelper:
|
|||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
layout = "mono" if waveform.shape[0] == 1 else "stereo"
|
||||
# Set up the output stream with appropriate properties
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
|
|
@ -333,7 +332,7 @@ class AudioSaveHelper:
|
|||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
if quality == "V0":
|
||||
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
|
|
@ -342,12 +341,12 @@ class AudioSaveHelper:
|
|||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: # format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||
format="flt",
|
||||
layout=layout,
|
||||
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
|
|
@ -371,7 +370,7 @@ class AudioSaveHelper:
|
|||
|
||||
@staticmethod
|
||||
def get_save_audio_ui(
|
||||
audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||
) -> SavedAudios:
|
||||
"""Save and instantly wrap for UI."""
|
||||
return SavedAudios(
|
||||
|
|
@ -387,7 +386,7 @@ class AudioSaveHelper:
|
|||
|
||||
|
||||
class PreviewImage(_UIOutput):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
|
||||
self.values = ImageSaveHelper.save_images(
|
||||
image,
|
||||
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||
|
|
@ -411,7 +410,7 @@ class PreviewMask(PreviewImage):
|
|||
|
||||
|
||||
class PreviewAudio(_UIOutput):
|
||||
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
|
||||
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
|
||||
self.values = AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||
|
|
@ -437,19 +436,9 @@ class PreviewUI3D(_UIOutput):
|
|||
def __init__(self, model_file, camera_info, **kwargs):
|
||||
self.model_file = model_file
|
||||
self.camera_info = camera_info
|
||||
self.bg_image_path = None
|
||||
bg_image = kwargs.get("bg_image", None)
|
||||
if bg_image is not None:
|
||||
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
|
||||
img = PILImage.fromarray(img_array)
|
||||
temp_dir = folder_paths.get_temp_directory()
|
||||
filename = f"bg_{uuid.uuid4().hex}.png"
|
||||
bg_image_path = os.path.join(temp_dir, filename)
|
||||
img.save(bg_image_path, compress_level=1)
|
||||
self.bg_image_path = f"temp/{filename}"
|
||||
|
||||
def as_dict(self):
|
||||
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
|
||||
return {"result": [self.model_file, self.camera_info]}
|
||||
|
||||
|
||||
class PreviewText(_UIOutput):
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
from ._ui import * # noqa: F403
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH
|
||||
from .image_types import SVG
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
|
|
@ -9,5 +8,4 @@ __all__ = [
|
|||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
"SVG",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,18 +0,0 @@
|
|||
from io import BytesIO
|
||||
|
||||
|
||||
class SVG:
|
||||
"""Stores SVG representations via a list of BytesIO objects."""
|
||||
|
||||
def __init__(self, data: list[BytesIO]):
|
||||
self.data = data
|
||||
|
||||
def combine(self, other: 'SVG') -> 'SVG':
|
||||
return SVG(self.data + other.data)
|
||||
|
||||
@staticmethod
|
||||
def combine_all(svgs: list['SVG']) -> 'SVG':
|
||||
all_svgs_list: list[BytesIO] = []
|
||||
for svg_item in svgs:
|
||||
all_svgs_list.extend(svg_item.data)
|
||||
return SVG(all_svgs_list)
|
||||
|
|
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from .._input import ImageInput, AudioInput
|
||||
from comfy_api.latest._input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from comfy_api.latest import (
|
|||
)
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
|
||||
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
|
||||
|
||||
|
||||
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
||||
|
|
@ -42,8 +42,4 @@ __all__ = [
|
|||
"InputImpl",
|
||||
"Types",
|
||||
"ComfyExtension",
|
||||
"io",
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@ from comfy_api.latest import ComfyAPI_latest
|
|||
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
|
||||
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from typing import List, Type
|
||||
|
||||
supported_versions: list[type[ComfyAPIBase]] = [
|
||||
supported_versions: List[Type[ComfyAPIBase]] = [
|
||||
ComfyAPI_latest,
|
||||
ComfyAPIAdapter_v0_0_2,
|
||||
ComfyAPIAdapter_v0_0_1,
|
||||
|
|
|
|||
|
|
@ -1,144 +0,0 @@
|
|||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
size: str | None = Field(None)
|
||||
seed: int | None = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(False)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: str | None = Field("adaptive")
|
||||
seed: int | None = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(False)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
class Seedream4TaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str = Field("url")
|
||||
image: list[str] | None = Field(None, description="Image URLs")
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
model: str = Field(...)
|
||||
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
|
||||
data: list = Field([], description="Contains information about the generated image(s).")
|
||||
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
|
||||
|
||||
|
||||
class TaskTextContent(BaseModel):
|
||||
type: str = Field("text")
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContent(BaseModel):
|
||||
type: str = Field("image_url")
|
||||
image_url: TaskImageContentUrl = Field(...)
|
||||
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: TaskStatusError | None = Field(None)
|
||||
content: TaskStatusResult | None = Field(None)
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
("1152x864 (4:3)", 1152, 864),
|
||||
("1280x720 (16:9)", 1280, 720),
|
||||
("720x1280 (9:16)", 720, 1280),
|
||||
("832x1248 (2:3)", 832, 1248),
|
||||
("1248x832 (3:2)", 1248, 832),
|
||||
("1512x648 (21:9)", 1512, 648),
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("2304x1728 (4:3)", 2304, 1728),
|
||||
("1728x2304 (3:4)", 1728, 2304),
|
||||
("2560x1440 (16:9)", 2560, 1440),
|
||||
("1440x2560 (9:16)", 1440, 2560),
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-lite-i2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-pro-250528": {
|
||||
"480p": 70,
|
||||
"720p": 85,
|
||||
"1080p": 115,
|
||||
},
|
||||
"seedance-1-0-pro-fast-251015": {
|
||||
"480p": 50,
|
||||
"720p": 65,
|
||||
"1080p": 100,
|
||||
},
|
||||
}
|
||||
|
|
@ -84,7 +84,15 @@ class GeminiSystemInstructionContent(BaseModel):
|
|||
description="A list of ordered parts that make up a single message. "
|
||||
"Different parts may have different IANA MIME types.",
|
||||
)
|
||||
role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
|
||||
role: GeminiRole = Field(
|
||||
...,
|
||||
description="The identity of the entity that creates the message. "
|
||||
"The following values are supported: "
|
||||
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
|
||||
"model: This indicates that the message is generated by the model. "
|
||||
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
|
||||
"For non-multi-turn conversations, this field can be left blank or unset.",
|
||||
)
|
||||
|
||||
|
||||
class GeminiFunctionDeclaration(BaseModel):
|
||||
|
|
@ -133,7 +141,6 @@ class GeminiImageGenerateContentRequest(BaseModel):
|
|||
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
|
||||
tools: list[GeminiTool] | None = Field(None)
|
||||
videoMetadata: GeminiVideoMetadata | None = Field(None)
|
||||
uploadImagesToStorage: bool = Field(True)
|
||||
|
||||
|
||||
class GeminiGenerateContentRequest(BaseModel):
|
||||
|
|
|
|||
|
|
@ -46,68 +46,21 @@ class TaskStatusVideoResult(BaseModel):
|
|||
url: str | None = Field(None, description="URL for generated video")
|
||||
|
||||
|
||||
class TaskStatusImageResult(BaseModel):
|
||||
index: int = Field(..., description="Image Number,0-9")
|
||||
url: str = Field(..., description="URL for generated image")
|
||||
|
||||
|
||||
class TaskStatusResults(BaseModel):
|
||||
class TaskStatusVideoResults(BaseModel):
|
||||
videos: list[TaskStatusVideoResult] | None = Field(None)
|
||||
images: list[TaskStatusImageResult] | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusResponseData(BaseModel):
|
||||
class TaskStatusVideoResponseData(BaseModel):
|
||||
created_at: int | None = Field(None, description="Task creation time")
|
||||
updated_at: int | None = Field(None, description="Task update time")
|
||||
task_status: str | None = None
|
||||
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
|
||||
task_id: str | None = Field(None, description="Task ID")
|
||||
task_result: TaskStatusResults | None = Field(None)
|
||||
task_result: TaskStatusVideoResults | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
class TaskStatusVideoResponse(BaseModel):
|
||||
code: int | None = Field(None, description="Error code")
|
||||
message: str | None = Field(None, description="Error message")
|
||||
request_id: str | None = Field(None, description="Request ID")
|
||||
data: TaskStatusResponseData | None = Field(None)
|
||||
|
||||
|
||||
class OmniImageParamImage(BaseModel):
|
||||
image: str = Field(...)
|
||||
|
||||
|
||||
class OmniProImageRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-image-o1")
|
||||
resolution: str = Field(..., description="'1k' or '2k'")
|
||||
aspect_ratio: str | None = Field(...)
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
n: int | None = Field(1, le=9)
|
||||
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
|
||||
|
||||
|
||||
class TextToVideoWithAudioRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-v2-6")
|
||||
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
|
||||
|
||||
class ImageToVideoWithAudioRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-v2-6")
|
||||
image: str = Field(...)
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
|
||||
|
||||
class MotionControlRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
image_url: str = Field(...)
|
||||
video_url: str = Field(...)
|
||||
keep_original_sound: str = Field(...)
|
||||
character_orientation: str = Field(...)
|
||||
mode: str = Field(..., description="'pro' or 'std'")
|
||||
data: TaskStatusVideoResponseData | None = Field(None)
|
||||
|
|
|
|||
|
|
@ -1,52 +0,0 @@
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Datum2(BaseModel):
|
||||
b64_json: str | None = Field(None, description="Base64 encoded image data")
|
||||
revised_prompt: str | None = Field(None, description="Revised prompt")
|
||||
url: str | None = Field(None, description="URL of the image")
|
||||
|
||||
|
||||
class InputTokensDetails(BaseModel):
|
||||
image_tokens: int | None = None
|
||||
text_tokens: int | None = None
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
input_tokens: int | None = None
|
||||
input_tokens_details: InputTokensDetails | None = None
|
||||
output_tokens: int | None = None
|
||||
total_tokens: int | None = None
|
||||
|
||||
|
||||
class OpenAIImageGenerationResponse(BaseModel):
|
||||
data: list[Datum2] | None = None
|
||||
usage: Usage | None = None
|
||||
|
||||
|
||||
class OpenAIImageEditRequest(BaseModel):
|
||||
background: str | None = Field(None, description="Background transparency")
|
||||
model: str = Field(...)
|
||||
moderation: str | None = Field(None)
|
||||
n: int | None = Field(None, description="The number of images to generate")
|
||||
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
|
||||
output_format: str | None = Field(None)
|
||||
prompt: str = Field(...)
|
||||
quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
|
||||
size: str | None = Field(None, description="Size of the output image")
|
||||
|
||||
|
||||
class OpenAIImageGenerationRequest(BaseModel):
|
||||
background: str | None = Field(None, description="Background transparency")
|
||||
model: str | None = Field(None)
|
||||
moderation: str | None = Field(None)
|
||||
n: int | None = Field(
|
||||
None,
|
||||
description="The number of images to generate.",
|
||||
)
|
||||
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
|
||||
output_format: str | None = Field(None)
|
||||
prompt: str = Field(...)
|
||||
quality: str | None = Field(None, description="The quality of the generated image")
|
||||
size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
|
||||
style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
from typing import Optional
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Pikaffect(str, Enum):
|
||||
Cake_ify = "Cake-ify"
|
||||
Crumble = "Crumble"
|
||||
Crush = "Crush"
|
||||
Decapitate = "Decapitate"
|
||||
Deflate = "Deflate"
|
||||
Dissolve = "Dissolve"
|
||||
Explode = "Explode"
|
||||
Eye_pop = "Eye-pop"
|
||||
Inflate = "Inflate"
|
||||
Levitate = "Levitate"
|
||||
Melt = "Melt"
|
||||
Peel = "Peel"
|
||||
Poke = "Poke"
|
||||
Squish = "Squish"
|
||||
Ta_da = "Ta-da"
|
||||
Tear = "Tear"
|
||||
|
||||
|
||||
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
|
||||
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
|
||||
duration: Optional[int] = Field(5)
|
||||
ingredientsMode: str = Field(...)
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
resolution: Optional[str] = Field('1080p')
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaGenerateResponse(BaseModel):
|
||||
video_id: str = Field(...)
|
||||
|
||||
|
||||
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
|
||||
duration: Optional[int] = 5
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
|
||||
duration: Optional[int] = Field(None, ge=5, le=10)
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: str = Field(...)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
|
||||
aspectRatio: Optional[float] = Field(
|
||||
1.7777777777777777,
|
||||
description='Aspect ratio (width / height)',
|
||||
ge=0.4,
|
||||
le=2.5,
|
||||
)
|
||||
duration: Optional[int] = 5
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: str = Field(...)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
pikaffect: Optional[str] = None
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
modifyRegionRoi: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class PikaStatusEnum(str, Enum):
|
||||
queued = "queued"
|
||||
started = "started"
|
||||
finished = "finished"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class PikaVideoResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
progress: Optional[int] = Field(None)
|
||||
status: PikaStatusEnum
|
||||
url: Optional[str] = Field(None)
|
||||
|
|
@ -5,17 +5,11 @@ from typing import Optional, List, Dict, Any, Union
|
|||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
class TripoModelVersion(str, Enum):
|
||||
v3_0_20250812 = 'v3.0-20250812'
|
||||
v2_5_20250123 = 'v2.5-20250123'
|
||||
v2_0_20240919 = 'v2.0-20240919'
|
||||
v1_4_20240625 = 'v1.4-20240625'
|
||||
|
||||
|
||||
class TripoGeometryQuality(str, Enum):
|
||||
standard = 'standard'
|
||||
detailed = 'detailed'
|
||||
|
||||
|
||||
class TripoTextureQuality(str, Enum):
|
||||
standard = 'standard'
|
||||
detailed = 'detailed'
|
||||
|
|
@ -67,20 +61,14 @@ class TripoSpec(str, Enum):
|
|||
class TripoAnimation(str, Enum):
|
||||
IDLE = "preset:idle"
|
||||
WALK = "preset:walk"
|
||||
RUN = "preset:run"
|
||||
DIVE = "preset:dive"
|
||||
CLIMB = "preset:climb"
|
||||
JUMP = "preset:jump"
|
||||
RUN = "preset:run"
|
||||
SLASH = "preset:slash"
|
||||
SHOOT = "preset:shoot"
|
||||
HURT = "preset:hurt"
|
||||
FALL = "preset:fall"
|
||||
TURN = "preset:turn"
|
||||
QUADRUPED_WALK = "preset:quadruped:walk"
|
||||
HEXAPOD_WALK = "preset:hexapod:walk"
|
||||
OCTOPOD_WALK = "preset:octopod:walk"
|
||||
SERPENTINE_MARCH = "preset:serpentine:march"
|
||||
AQUATIC_MARCH = "preset:aquatic:march"
|
||||
|
||||
class TripoStylizeStyle(str, Enum):
|
||||
LEGO = "lego"
|
||||
|
|
@ -117,11 +105,6 @@ class TripoTaskStatus(str, Enum):
|
|||
BANNED = "banned"
|
||||
EXPIRED = "expired"
|
||||
|
||||
class TripoFbxPreset(str, Enum):
|
||||
BLENDER = "blender"
|
||||
MIXAMO = "mixamo"
|
||||
_3DSMAX = "3dsmax"
|
||||
|
||||
class TripoFileTokenReference(BaseModel):
|
||||
type: Optional[str] = Field(None, description='The type of the reference')
|
||||
file_token: str
|
||||
|
|
@ -159,7 +142,6 @@ class TripoTextToModelRequest(BaseModel):
|
|||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||
style: Optional[TripoStyle] = None
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||
|
|
@ -174,7 +156,6 @@ class TripoImageToModelRequest(BaseModel):
|
|||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
|
||||
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
|
|
@ -192,7 +173,6 @@ class TripoMultiviewToModelRequest(BaseModel):
|
|||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
|
||||
|
|
@ -239,24 +219,14 @@ class TripoConvertModelRequest(BaseModel):
|
|||
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
|
||||
format: TripoConvertFormat = Field(..., description='The format to convert to')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
|
||||
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
|
||||
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
|
||||
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
|
||||
texture_size: Optional[int] = Field(None, description='The size of the texture')
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
|
||||
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
|
||||
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
|
||||
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
|
||||
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
|
||||
texture_size: Optional[int] = Field(4096, description='The size of the texture')
|
||||
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
|
||||
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
|
||||
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
|
||||
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
|
||||
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
|
||||
bake: Optional[bool] = Field(None, description='Whether to bake the model')
|
||||
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
|
||||
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
|
||||
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
|
||||
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
|
||||
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
|
||||
|
||||
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
|
||||
|
||||
class TripoTaskRequest(RootModel):
|
||||
root: Union[
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ class Response1(BaseModel):
|
|||
raiMediaFilteredReasons: Optional[list[str]] = Field(
|
||||
None, description='Reasons why media was filtered by responsible AI policies'
|
||||
)
|
||||
videos: Optional[list[Video]] = Field(None)
|
||||
videos: Optional[list[Video]] = None
|
||||
|
||||
|
||||
class VeoGenVidPollResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from inspect import cleandoc
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.bfl_api import (
|
||||
BFLFluxExpandImageRequest,
|
||||
BFLFluxFillImageRequest,
|
||||
|
|
@ -26,7 +28,7 @@ from comfy_api_nodes.util import (
|
|||
)
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: Input.Image):
|
||||
def convert_mask_to_image(mask: torch.Tensor):
|
||||
"""
|
||||
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
|
||||
"""
|
||||
|
|
@ -36,6 +38,9 @@ def convert_mask_to_image(mask: Input.Image):
|
|||
|
||||
|
||||
class FluxProUltraImageNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
|
|
@ -43,7 +48,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
|||
node_id="FluxProUltraImageNode",
|
||||
display_name="Flux 1.1 [pro] Ultra Image",
|
||||
category="api node/image/BFL",
|
||||
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
|
|
@ -112,7 +117,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
|||
prompt_upsampling: bool = False,
|
||||
raw: bool = False,
|
||||
seed: int = 0,
|
||||
image_prompt: Input.Image | None = None,
|
||||
image_prompt: torch.Tensor | None = None,
|
||||
image_prompt_strength: float = 0.1,
|
||||
) -> IO.NodeOutput:
|
||||
if image_prompt is None:
|
||||
|
|
@ -150,6 +155,9 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
|||
|
||||
|
||||
class FluxKontextProImageNode(IO.ComfyNode):
|
||||
"""
|
||||
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
|
|
@ -157,7 +165,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
|||
node_id=cls.NODE_ID,
|
||||
display_name=cls.DISPLAY_NAME,
|
||||
category="api node/image/BFL",
|
||||
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
|
|
@ -223,7 +231,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
|||
aspect_ratio: str,
|
||||
guidance: float,
|
||||
steps: int,
|
||||
input_image: Input.Image | None = None,
|
||||
input_image: torch.Tensor | None = None,
|
||||
seed=0,
|
||||
prompt_upsampling=False,
|
||||
) -> IO.NodeOutput:
|
||||
|
|
@ -263,14 +271,20 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
|||
|
||||
|
||||
class FluxKontextMaxImageNode(FluxKontextProImageNode):
|
||||
"""
|
||||
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
|
||||
NODE_ID = "FluxKontextMaxImageNode"
|
||||
DISPLAY_NAME = "Flux.1 Kontext [max] Image"
|
||||
|
||||
|
||||
class FluxProExpandNode(IO.ComfyNode):
|
||||
"""
|
||||
Outpaints image based on prompt.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
|
|
@ -278,7 +292,7 @@ class FluxProExpandNode(IO.ComfyNode):
|
|||
node_id="FluxProExpandNode",
|
||||
display_name="Flux.1 Expand Image",
|
||||
category="api node/image/BFL",
|
||||
description="Outpaints image based on prompt.",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
|
|
@ -357,7 +371,7 @@ class FluxProExpandNode(IO.ComfyNode):
|
|||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
top: int,
|
||||
|
|
@ -404,6 +418,9 @@ class FluxProExpandNode(IO.ComfyNode):
|
|||
|
||||
|
||||
class FluxProFillNode(IO.ComfyNode):
|
||||
"""
|
||||
Inpaints image based on mask and prompt.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
|
|
@ -411,7 +428,7 @@ class FluxProFillNode(IO.ComfyNode):
|
|||
node_id="FluxProFillNode",
|
||||
display_name="Flux.1 Fill Image",
|
||||
category="api node/image/BFL",
|
||||
description="Inpaints image based on mask and prompt.",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask"),
|
||||
|
|
@ -463,8 +480,8 @@ class FluxProFillNode(IO.ComfyNode):
|
|||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
mask: Input.Image,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
steps: int,
|
||||
|
|
@ -508,15 +525,11 @@ class FluxProFillNode(IO.ComfyNode):
|
|||
|
||||
class Flux2ProImageNode(IO.ComfyNode):
|
||||
|
||||
NODE_ID = "Flux2ProImageNode"
|
||||
DISPLAY_NAME = "Flux.2 [pro] Image"
|
||||
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id=cls.NODE_ID,
|
||||
display_name=cls.DISPLAY_NAME,
|
||||
node_id="Flux2ProImageNode",
|
||||
display_name="Flux.2 [pro] Image",
|
||||
category="api node/image/BFL",
|
||||
description="Generates images synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
|
|
@ -550,11 +563,12 @@ class Flux2ProImageNode(IO.ComfyNode):
|
|||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=True,
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation.",
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
|
||||
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
|
|
@ -573,7 +587,7 @@ class Flux2ProImageNode(IO.ComfyNode):
|
|||
height: int,
|
||||
seed: int,
|
||||
prompt_upsampling: bool,
|
||||
images: Input.Image | None = None,
|
||||
images: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
reference_images = {}
|
||||
if images is not None:
|
||||
|
|
@ -584,7 +598,7 @@ class Flux2ProImageNode(IO.ComfyNode):
|
|||
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
|
||||
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=Flux2ProGenerateRequest(
|
||||
prompt=prompt,
|
||||
|
|
@ -618,13 +632,6 @@ class Flux2ProImageNode(IO.ComfyNode):
|
|||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class Flux2MaxImageNode(Flux2ProImageNode):
|
||||
|
||||
NODE_ID = "Flux2MaxImageNode"
|
||||
DISPLAY_NAME = "Flux.2 [max] Image"
|
||||
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
|
||||
|
||||
|
||||
class BFLExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
|
|
@ -635,7 +642,6 @@ class BFLExtension(ComfyExtension):
|
|||
FluxProExpandNode,
|
||||
FluxProFillNode,
|
||||
Flux2ProImageNode,
|
||||
Flux2MaxImageNode,
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,27 +1,13 @@
|
|||
import logging
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance_api import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
Image2ImageTaskCreationRequest,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
Seedream4Options,
|
||||
Seedream4TaskCreationRequest,
|
||||
TaskCreationResponse,
|
||||
TaskImageContent,
|
||||
TaskImageContentUrl,
|
||||
TaskStatusResponse,
|
||||
TaskTextContent,
|
||||
Text2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
)
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
|
|
@ -43,6 +29,162 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
|||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
|
||||
class Text2ImageModelName(str, Enum):
|
||||
seedream_3 = "seedream-3-0-t2i-250415"
|
||||
|
||||
|
||||
class Image2ImageModelName(str, Enum):
|
||||
seededit_3 = "seededit-3-0-i2i-250628"
|
||||
|
||||
|
||||
class Text2VideoModelName(str, Enum):
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
|
||||
|
||||
|
||||
class Image2VideoModelName(str, Enum):
|
||||
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
|
||||
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: Text2ImageModelName = Text2ImageModelName.seedream_3
|
||||
prompt: str = Field(...)
|
||||
response_format: Optional[str] = Field("url")
|
||||
size: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
|
||||
watermark: Optional[bool] = Field(True)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: Image2ImageModelName = Image2ImageModelName.seededit_3
|
||||
prompt: str = Field(...)
|
||||
response_format: Optional[str] = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: Optional[str] = Field("adaptive")
|
||||
seed: Optional[int] = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
|
||||
watermark: Optional[bool] = Field(True)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
class Seedream4TaskCreationRequest(BaseModel):
|
||||
model: str = Field("seedream-4-0-250828")
|
||||
prompt: str = Field(...)
|
||||
response_format: str = Field("url")
|
||||
image: Optional[list[str]] = Field(None, description="Image URLs")
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
model: str = Field(...)
|
||||
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
|
||||
data: list = Field([], description="Contains information about the generated image(s).")
|
||||
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
|
||||
|
||||
|
||||
class TaskTextContent(BaseModel):
|
||||
type: str = Field("text")
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContent(BaseModel):
|
||||
type: str = Field("image_url")
|
||||
image_url: TaskImageContentUrl = Field(...)
|
||||
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
|
||||
content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: Optional[TaskStatusError] = Field(None)
|
||||
content: Optional[TaskStatusResult] = Field(None)
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
("1152x864 (4:3)", 1152, 864),
|
||||
("1280x720 (16:9)", 1280, 720),
|
||||
("720x1280 (9:16)", 720, 1280),
|
||||
("832x1248 (2:3)", 832, 1248),
|
||||
("1248x832 (3:2)", 1248, 832),
|
||||
("1512x648 (21:9)", 1512, 648),
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("2304x1728 (4:3)", 2304, 1728),
|
||||
("1728x2304 (3:4)", 1728, 2304),
|
||||
("2560x1440 (16:9)", 2560, 1440),
|
||||
("1440x2560 (9:16)", 1440, 2560),
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-lite-i2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-pro-250528": {
|
||||
"480p": 70,
|
||||
"720p": 85,
|
||||
"1080p": 115,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
||||
|
|
@ -52,6 +194,13 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
|||
return response.data[0]["url"]
|
||||
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if hasattr(response, "content") and response.content:
|
||||
return response.content.video_url
|
||||
return None
|
||||
|
||||
|
||||
class ByteDanceImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
|
|
@ -62,7 +211,12 @@ class ByteDanceImageNode(IO.ComfyNode):
|
|||
category="api node/image/ByteDance",
|
||||
description="Generate images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Text2ImageModelName,
|
||||
default=Text2ImageModelName.seedream_3,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
|
|
@ -112,7 +266,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
|||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -181,7 +335,12 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
|||
category="api node/image/ByteDance",
|
||||
description="Edit images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Image2ImageModelName,
|
||||
default=Image2ImageModelName.seededit_3,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The base image to edit",
|
||||
|
|
@ -215,7 +374,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
|||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -229,14 +388,13 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
|||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
guidance_scale: float,
|
||||
|
|
@ -270,13 +428,13 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedreamNode",
|
||||
display_name="ByteDance Seedream 4.5",
|
||||
display_name="ByteDance Seedream 4",
|
||||
category="api node/image/ByteDance",
|
||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedream-4-5-251128", "seedream-4-0-250828"],
|
||||
options=["seedream-4-0-250828"],
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.String.Input(
|
||||
|
|
@ -301,7 +459,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=8,
|
||||
step=64,
|
||||
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -310,7 +468,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=8,
|
||||
step=64,
|
||||
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -347,7 +505,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image.',
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -374,14 +532,14 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: Input.Image | None = None,
|
||||
image: torch.Tensor = None,
|
||||
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
|
||||
width: int = 2048,
|
||||
height: int = 2048,
|
||||
sequential_image_generation: str = "disabled",
|
||||
max_images: int = 1,
|
||||
seed: int = 0,
|
||||
watermark: bool = False,
|
||||
watermark: bool = True,
|
||||
fail_on_partial: bool = True,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
|
|
@ -397,18 +555,6 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||
raise ValueError(
|
||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
|
||||
)
|
||||
out_num_pixels = w * h
|
||||
mp_provided = out_num_pixels / 1_000_000.0
|
||||
if "seedream-4-5" in model and out_num_pixels < 3686400:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that the selected model can generate is 0.92MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
n_input_images = get_number_of_images(image) if image is not None else 0
|
||||
if n_input_images > 10:
|
||||
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
|
||||
|
|
@ -461,8 +607,9 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
|||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
options=Text2VideoModelName,
|
||||
default=Text2VideoModelName.seedance_1_pro,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
|
|
@ -508,7 +655,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
|||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -567,8 +714,9 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
|||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
options=Image2VideoModelName,
|
||||
default=Image2VideoModelName.seedance_1_pro,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
|
|
@ -618,7 +766,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
|||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -639,7 +787,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
|||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: Input.Image,
|
||||
image: torch.Tensor,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
|
|
@ -685,8 +833,9 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
|||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
options=[model.value for model in Image2VideoModelName],
|
||||
default=Image2VideoModelName.seedance_1_lite.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
|
|
@ -740,7 +889,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
|||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -761,8 +910,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
|||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
first_frame: Input.Image,
|
||||
last_frame: Input.Image,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
|
|
@ -819,8 +968,9 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
options=[Image2VideoModelName.seedance_1_lite.value],
|
||||
default=Image2VideoModelName.seedance_1_lite.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
|
|
@ -863,7 +1013,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -884,7 +1034,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
images: Input.Image,
|
||||
images: torch.Tensor,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
|
|
@ -919,8 +1069,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||
|
||||
async def process_video_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
||||
estimated_duration: Optional[int],
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
|
|
@ -935,7 +1085,7 @@ async def process_video_task(
|
|||
estimated_duration=estimated_duration,
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
|
||||
|
||||
|
||||
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ import torch
|
|||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
|
|
@ -26,15 +27,12 @@ from comfy_api_nodes.apis.gemini_api import (
|
|||
GeminiMimeType,
|
||||
GeminiPart,
|
||||
GeminiRole,
|
||||
GeminiSystemInstructionContent,
|
||||
GeminiTextPart,
|
||||
Modality,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
download_url_to_image_tensor,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
|
|
@ -45,14 +43,6 @@ from comfy_api_nodes.util import (
|
|||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
GEMINI_IMAGE_SYS_PROMPT = (
|
||||
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
|
||||
"Interpret all user input—regardless of "
|
||||
"format, intent, or abstraction—as literal visual directives for image composition.\n"
|
||||
"If a prompt is conversational or lacks specific visual details, "
|
||||
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
|
||||
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
|
||||
)
|
||||
|
||||
|
||||
class GeminiModel(str, Enum):
|
||||
|
|
@ -78,7 +68,7 @@ class GeminiImageModel(str, Enum):
|
|||
|
||||
async def create_image_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: Input.Image,
|
||||
images: torch.Tensor,
|
||||
image_limit: int = 0,
|
||||
) -> list[GeminiPart]:
|
||||
image_parts: list[GeminiPart] = []
|
||||
|
|
@ -142,11 +132,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
|||
)
|
||||
parts = []
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part_type == "text" and part.text:
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
parts.append(part)
|
||||
elif part.inlineData and part.inlineData.mimeType == part_type:
|
||||
parts.append(part)
|
||||
elif part.fileData and part.fileData.mimeType == part_type:
|
||||
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
|
|
@ -166,15 +154,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
|||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
for part in parts:
|
||||
if part.inlineData:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
else:
|
||||
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
image_tensors.append(returned_image)
|
||||
if len(image_tensors) == 0:
|
||||
return torch.zeros((1, 1024, 1024, 4))
|
||||
|
|
@ -292,13 +277,6 @@ class GeminiNode(IO.ComfyNode):
|
|||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(),
|
||||
|
|
@ -315,9 +293,7 @@ class GeminiNode(IO.ComfyNode):
|
|||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert video input to Gemini API compatible parts."""
|
||||
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
|
|
@ -367,11 +343,10 @@ class GeminiNode(IO.ComfyNode):
|
|||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: Input.Image | None = None,
|
||||
images: torch.Tensor | None = None,
|
||||
audio: Input.Audio | None = None,
|
||||
video: Input.Video | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
|
|
@ -388,10 +363,7 @@ class GeminiNode(IO.ComfyNode):
|
|||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
# Create response
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
|
|
@ -401,8 +373,7 @@ class GeminiNode(IO.ComfyNode):
|
|||
role=GeminiRole.user,
|
||||
parts=parts,
|
||||
)
|
||||
],
|
||||
systemInstruction=gemini_system_prompt,
|
||||
]
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
|
|
@ -552,13 +523,6 @@ class GeminiImage(IO.ComfyNode):
|
|||
"'IMAGE+TEXT' to return both the generated image and a text response.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
|
|
@ -578,11 +542,10 @@ class GeminiImage(IO.ComfyNode):
|
|||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: Input.Image | None = None,
|
||||
images: torch.Tensor | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
aspect_ratio: str = "auto",
|
||||
response_modalities: str = "IMAGE+TEXT",
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
|
|
@ -596,13 +559,9 @@ class GeminiImage(IO.ComfyNode):
|
|||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
|
|
@ -611,12 +570,11 @@ class GeminiImage(IO.ComfyNode):
|
|||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
|
||||
|
||||
|
||||
class GeminiImage2(IO.ComfyNode):
|
||||
|
|
@ -682,13 +640,6 @@ class GeminiImage2(IO.ComfyNode):
|
|||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
|
|
@ -711,9 +662,8 @@ class GeminiImage2(IO.ComfyNode):
|
|||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
response_modalities: str,
|
||||
images: Input.Image | None = None,
|
||||
images: torch.Tensor | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
|
||||
|
|
@ -729,13 +679,9 @@ class GeminiImage2(IO.ComfyNode):
|
|||
if aspect_ratio != "auto":
|
||||
image_config.aspectRatio = aspect_ratio
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
|
||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
|
|
@ -744,12 +690,11 @@ class GeminiImage2(IO.ComfyNode):
|
|||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ For source of truth on the allowed permutations of request fields, please refere
|
|||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
|
@ -50,17 +49,12 @@ from comfy_api_nodes.apis import (
|
|||
KlingSingleImageEffectModelName,
|
||||
)
|
||||
from comfy_api_nodes.apis.kling_api import (
|
||||
ImageToVideoWithAudioRequest,
|
||||
MotionControlRequest,
|
||||
OmniImageParamImage,
|
||||
OmniParamImage,
|
||||
OmniParamVideo,
|
||||
OmniProFirstLastFrameRequest,
|
||||
OmniProImageRequest,
|
||||
OmniProReferences2VideoRequest,
|
||||
OmniProText2VideoRequest,
|
||||
TaskStatusResponse,
|
||||
TextToVideoWithAudioRequest,
|
||||
TaskStatusVideoResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
|
|
@ -106,6 +100,10 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
|
|||
|
||||
|
||||
MODE_TEXT2VIDEO = {
|
||||
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
|
||||
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
|
||||
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
|
||||
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
|
||||
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
|
||||
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
|
||||
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
|
||||
|
|
@ -126,6 +124,8 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
|
|||
|
||||
|
||||
MODE_START_END_FRAME = {
|
||||
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
|
||||
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
|
||||
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
|
||||
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
|
||||
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
|
||||
|
|
@ -210,36 +210,7 @@ VOICES_CONFIG = {
|
|||
}
|
||||
|
||||
|
||||
def normalize_omni_prompt_references(prompt: str) -> str:
|
||||
"""
|
||||
Rewrites Kling Omni-style placeholders used in the app, like:
|
||||
|
||||
@image, @image1, @image2, ... @imageN
|
||||
@video, @video1, @video2, ... @videoN
|
||||
|
||||
into the API-compatible form:
|
||||
|
||||
<<<image_1>>>, <<<image_2>>>, ...
|
||||
<<<video_1>>>, <<<video_2>>>, ...
|
||||
|
||||
This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
|
||||
"""
|
||||
if not prompt:
|
||||
return prompt
|
||||
|
||||
def _image_repl(match):
|
||||
return f"<<<image_{match.group('idx') or '1'}>>>"
|
||||
|
||||
def _video_repl(match):
|
||||
return f"<<<video_{match.group('idx') or '1'}>>>"
|
||||
|
||||
# (?<!\w) avoids matching e.g. "test@image.com"
|
||||
# (?!\w) makes sure we only match @image / @image<digits> and not @imageFoo
|
||||
prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)", _image_repl, prompt)
|
||||
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
|
||||
|
||||
|
||||
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
|
||||
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput:
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
|
|
@ -247,9 +218,8 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
|
|||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
response_model=TaskStatusVideoResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
max_poll_attempts=160,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
|
@ -480,12 +450,12 @@ async def execute_image2video(
|
|||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
|
||||
response_model=KlingImage2VideoResponse,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
|
||||
response_model=KlingImage2VideoResponse,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
|
|
@ -749,7 +719,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
|||
IO.Combo.Input(
|
||||
"mode",
|
||||
options=modes,
|
||||
default=modes[8],
|
||||
default=modes[4],
|
||||
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
|
||||
),
|
||||
],
|
||||
|
|
@ -807,7 +777,6 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
|
|
@ -827,19 +796,17 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
response_model=TaskStatusVideoResponse,
|
||||
data=OmniProText2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
|
@ -862,7 +829,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
|
||||
IO.Combo.Input("duration", options=["5", "10"]),
|
||||
IO.Image.Input("first_frame"),
|
||||
IO.Image.Input(
|
||||
"end_frame",
|
||||
|
|
@ -875,7 +842,6 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||
optional=True,
|
||||
tooltip="Up to 6 additional reference images.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
|
|
@ -897,16 +863,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||
first_frame: Input.Image,
|
||||
end_frame: Input.Image | None = None,
|
||||
reference_images: Input.Image | None = None,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
if end_frame is not None and reference_images is not None:
|
||||
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
|
||||
if duration not in (5, 10) and end_frame is None and reference_images is None:
|
||||
raise ValueError(
|
||||
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
|
||||
)
|
||||
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
|
||||
image_list: list[OmniParamImage] = [
|
||||
|
|
@ -935,13 +895,12 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
response_model=TaskStatusVideoResponse,
|
||||
data=OmniProFirstLastFrameRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
|
@ -970,7 +929,6 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||
"reference_images",
|
||||
tooltip="Up to 7 reference images.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
|
|
@ -991,9 +949,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||
aspect_ratio: str,
|
||||
duration: int,
|
||||
reference_images: Input.Image,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
if get_number_of_images(reference_images) > 7:
|
||||
raise ValueError("The maximum number of reference images is 7.")
|
||||
|
|
@ -1006,14 +962,13 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
response_model=TaskStatusVideoResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
|
@ -1045,7 +1000,6 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
|||
tooltip="Up to 4 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
|
|
@ -1068,9 +1022,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
|||
reference_video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
reference_images: Input.Image | None = None,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
|
||||
validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
|
||||
|
|
@ -1093,7 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
|||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
response_model=TaskStatusVideoResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
|
|
@ -1101,7 +1053,6 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
|||
duration=str(duration),
|
||||
image_list=image_list if image_list else None,
|
||||
video_list=video_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
|
@ -1131,7 +1082,6 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||
tooltip="Up to 4 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
|
|
@ -1152,9 +1102,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||
video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
reference_images: Input.Image | None = None,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
|
||||
validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
|
||||
|
|
@ -1177,7 +1125,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
response_model=TaskStatusVideoResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
|
|
@ -1185,96 +1133,11 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||
duration=None,
|
||||
image_list=image_list if image_list else None,
|
||||
video_list=video_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class OmniProImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProImageNode",
|
||||
display_name="Kling Omni Image (Pro)",
|
||||
category="api node/image/Kling",
|
||||
description="Create or edit images with the latest model from Kling.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-image-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the image content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1K", "2K"]),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
|
||||
),
|
||||
IO.Image.Input(
|
||||
"reference_images",
|
||||
tooltip="Up to 10 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
reference_images: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
image_list: list[OmniImageParamImage] = []
|
||||
if reference_images is not None:
|
||||
if get_number_of_images(reference_images) > 10:
|
||||
raise ValueError("The maximum number of reference images is 10.")
|
||||
for i in reference_images:
|
||||
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
|
||||
image_list.append(OmniImageParamImage(image=i))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProImageRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
resolution=resolution.lower(),
|
||||
aspect_ratio=aspect_ratio,
|
||||
image_list=image_list if image_list else None,
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
|
||||
|
||||
|
||||
class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
"""
|
||||
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
|
||||
|
|
@ -1344,8 +1207,9 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
|||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingImage2VideoNode",
|
||||
display_name="Kling Image(First Frame) to Video",
|
||||
display_name="Kling Image to Video",
|
||||
category="api node/video/Kling",
|
||||
description="Kling Image to Video Node",
|
||||
inputs=[
|
||||
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
|
|
@ -1503,7 +1367,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
|||
IO.Combo.Input(
|
||||
"mode",
|
||||
options=modes,
|
||||
default=modes[6],
|
||||
default=modes[8],
|
||||
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
|
||||
),
|
||||
],
|
||||
|
|
@ -1966,7 +1830,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
|||
IO.Combo.Input(
|
||||
"model_name",
|
||||
options=[i.value for i in KlingImageGenModelName],
|
||||
default="kling-v2",
|
||||
default="kling-v1",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
|
|
@ -2049,221 +1913,6 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
|||
return IO.NodeOutput(await image_result_to_node_output(images))
|
||||
|
||||
|
||||
class TextToVideoWithAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingTextToVideoWithAudio",
|
||||
display_name="Kling Text to Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
|
||||
IO.Combo.Input("mode", options=["pro"]),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
IO.Boolean.Input("generate_audio", default=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
mode: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
generate_audio: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=TextToVideoWithAudioRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
mode=mode,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
sound="on" if generate_audio else "off",
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
class ImageToVideoWithAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingImageToVideoWithAudio",
|
||||
display_name="Kling Image(First Frame) to Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.Image.Input("start_frame"),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
|
||||
IO.Combo.Input("mode", options=["pro"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
IO.Boolean.Input("generate_audio", default=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
start_frame: Input.Image,
|
||||
prompt: str,
|
||||
mode: str,
|
||||
duration: int,
|
||||
generate_audio: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
validate_image_dimensions(start_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=ImageToVideoWithAudioRequest(
|
||||
model_name=model_name,
|
||||
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
|
||||
prompt=prompt,
|
||||
mode=mode,
|
||||
duration=str(duration),
|
||||
sound="on" if generate_audio else "off",
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
class MotionControl(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingMotionControl",
|
||||
display_name="Kling Motion Control",
|
||||
category="api node/video/Kling",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True),
|
||||
IO.Image.Input("reference_image"),
|
||||
IO.Video.Input(
|
||||
"reference_video",
|
||||
tooltip="Motion reference video used to drive movement/expression.\n"
|
||||
"Duration limits depend on character_orientation:\n"
|
||||
" - image: 3–10s (max 10s)\n"
|
||||
" - video: 3–30s (max 30s)",
|
||||
),
|
||||
IO.Boolean.Input("keep_original_sound", default=True),
|
||||
IO.Combo.Input(
|
||||
"character_orientation",
|
||||
options=["video", "image"],
|
||||
tooltip="Controls where the character's facing/orientation comes from.\n"
|
||||
"video: movements, expressions, camera moves, and orientation "
|
||||
"follow the motion reference video (other details via prompt).\n"
|
||||
"image: movements and expressions still follow the motion reference video, "
|
||||
"but the character orientation matches the reference image (camera/other details via prompt).",
|
||||
),
|
||||
IO.Combo.Input("mode", options=["pro", "std"]),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
reference_image: Input.Image,
|
||||
reference_video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
character_orientation: str,
|
||||
mode: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=2500)
|
||||
validate_image_dimensions(reference_image, min_width=340, min_height=340)
|
||||
validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1))
|
||||
if character_orientation == "image":
|
||||
validate_video_duration(reference_video, min_duration=3, max_duration=10)
|
||||
else:
|
||||
validate_video_duration(reference_video, min_duration=3, max_duration=30)
|
||||
validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=MotionControlRequest(
|
||||
prompt=prompt,
|
||||
image_url=(await upload_images_to_comfyapi(cls, reference_image))[0],
|
||||
video_url=await upload_video_to_comfyapi(cls, reference_video),
|
||||
keep_original_sound="yes" if keep_original_sound else "no",
|
||||
character_orientation=character_orientation,
|
||||
mode=mode,
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
class KlingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
|
|
@ -2286,10 +1935,6 @@ class KlingExtension(ComfyExtension):
|
|||
OmniProImageToVideoNode,
|
||||
OmniProVideoToVideoNode,
|
||||
OmniProEditVideoNode,
|
||||
OmniProImageNode,
|
||||
TextToVideoWithAudio,
|
||||
ImageToVideoWithAudio,
|
||||
MotionControl,
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
|
|
@ -23,9 +26,9 @@ class ExecuteTaskRequest(BaseModel):
|
|||
model: str = Field(...)
|
||||
duration: int = Field(...)
|
||||
resolution: str = Field(...)
|
||||
fps: int | None = Field(25)
|
||||
generate_audio: bool | None = Field(True)
|
||||
image_uri: str | None = Field(None)
|
||||
fps: Optional[int] = Field(25)
|
||||
generate_audio: Optional[bool] = Field(True)
|
||||
image_uri: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class TextToVideoNode(IO.ComfyNode):
|
||||
|
|
@ -100,7 +103,7 @@ class TextToVideoNode(IO.ComfyNode):
|
|||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class ImageToVideoNode(IO.ComfyNode):
|
||||
|
|
@ -150,7 +153,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
|||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
image: torch.Tensor,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
|
|
@ -180,7 +183,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
|||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class LtxvApiExtension(ComfyExtension):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
|
|
@ -58,7 +61,7 @@ def validate_task_creation_response(response) -> None:
|
|||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
|
||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
"""
|
||||
Validates and processes video input for Moonvalley Video-to-Video generation.
|
||||
|
||||
|
|
@ -79,7 +82,7 @@ def validate_video_to_video_input(video: Input.Video) -> Input.Video:
|
|||
return _validate_and_trim_duration(video)
|
||||
|
||||
|
||||
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
|
||||
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
|
||||
"""Extracts video dimensions with error handling."""
|
||||
try:
|
||||
return video.get_dimensions()
|
||||
|
|
@ -103,7 +106,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
|||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
"""Validates video duration and trims to 5 seconds if needed."""
|
||||
duration = video.get_duration()
|
||||
_validate_minimum_duration(duration)
|
||||
|
|
@ -116,7 +119,7 @@ def _validate_minimum_duration(duration: float) -> None:
|
|||
raise ValueError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
|
||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
"""Trims video to 5 seconds if longer."""
|
||||
if duration > 5:
|
||||
return trim_video(video, 5)
|
||||
|
|
@ -238,7 +241,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
|||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
resolution: str,
|
||||
|
|
@ -359,9 +362,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
|||
prompt: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
video: Input.Video | None = None,
|
||||
video: Optional[VideoInput] = None,
|
||||
control_type: str = "Motion Transfer",
|
||||
motion_intensity: int | None = 100,
|
||||
motion_intensity: Optional[int] = 100,
|
||||
steps=33,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
|
|
|
|||
|
|
@ -1,45 +1,46 @@
|
|||
import base64
|
||||
from io import BytesIO
|
||||
import os
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
|
||||
from inspect import cleandoc
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
import folder_paths
|
||||
import base64
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
CreateModelResponseProperties,
|
||||
Detail,
|
||||
InputContent,
|
||||
InputFileContent,
|
||||
InputImageContent,
|
||||
InputMessage,
|
||||
InputMessageContentList,
|
||||
InputTextContent,
|
||||
Item,
|
||||
OpenAIImageGenerationRequest,
|
||||
OpenAIImageEditRequest,
|
||||
OpenAIImageGenerationResponse,
|
||||
OpenAICreateResponse,
|
||||
OpenAIResponse,
|
||||
CreateModelResponseProperties,
|
||||
Item,
|
||||
OutputContent,
|
||||
InputImageContent,
|
||||
Detail,
|
||||
InputTextContent,
|
||||
InputMessage,
|
||||
InputMessageContentList,
|
||||
InputContent,
|
||||
InputFileContent,
|
||||
)
|
||||
from comfy_api_nodes.apis.openai_api import (
|
||||
OpenAIImageEditRequest,
|
||||
OpenAIImageGenerationRequest,
|
||||
OpenAIImageGenerationResponse,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
downscale_image_tensor,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
download_url_to_bytesio,
|
||||
validate_string,
|
||||
tensor_to_base64_string,
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
poll_op,
|
||||
text_filepath_to_data_uri,
|
||||
)
|
||||
|
||||
|
||||
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
|
||||
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
|
||||
|
||||
|
|
@ -97,6 +98,9 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
|
|||
|
||||
|
||||
class OpenAIDalle2(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
|
|
@ -104,7 +108,7 @@ class OpenAIDalle2(IO.ComfyNode):
|
|||
node_id="OpenAIDalle2",
|
||||
display_name="OpenAI DALL·E 2",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
|
|
@ -230,6 +234,9 @@ class OpenAIDalle2(IO.ComfyNode):
|
|||
|
||||
|
||||
class OpenAIDalle3(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
|
|
@ -237,7 +244,7 @@ class OpenAIDalle3(IO.ComfyNode):
|
|||
node_id="OpenAIDalle3",
|
||||
display_name="OpenAI DALL·E 3",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
|
|
@ -319,16 +326,10 @@ class OpenAIDalle3(IO.ComfyNode):
|
|||
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||
|
||||
|
||||
def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None:
|
||||
# https://platform.openai.com/docs/pricing
|
||||
return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0
|
||||
|
||||
|
||||
def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None:
|
||||
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
|
||||
|
||||
|
||||
class OpenAIGPTImage1(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
|
|
@ -336,13 +337,13 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||
node_id="OpenAIGPTImage1",
|
||||
display_name="OpenAI GPT Image 1",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Text prompt for GPT Image",
|
||||
tooltip="Text prompt for GPT Image 1",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
|
|
@ -364,8 +365,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||
),
|
||||
IO.Combo.Input(
|
||||
"background",
|
||||
default="auto",
|
||||
options=["auto", "opaque", "transparent"],
|
||||
default="opaque",
|
||||
options=["opaque", "transparent"],
|
||||
tooltip="Return image with or without background",
|
||||
optional=True,
|
||||
),
|
||||
|
|
@ -396,11 +397,6 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["gpt-image-1", "gpt-image-1.5"],
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
|
|
@ -416,34 +412,32 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
seed: int = 0,
|
||||
quality: str = "low",
|
||||
background: str = "opaque",
|
||||
image: Input.Image | None = None,
|
||||
mask: Input.Image | None = None,
|
||||
n: int = 1,
|
||||
size: str = "1024x1024",
|
||||
model: str = "gpt-image-1",
|
||||
prompt,
|
||||
seed=0,
|
||||
quality="low",
|
||||
background="opaque",
|
||||
image=None,
|
||||
mask=None,
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
if mask is not None and image is None:
|
||||
raise ValueError("Cannot use a mask without an input image")
|
||||
|
||||
if model == "gpt-image-1":
|
||||
price_extractor = calculate_tokens_price_image_1
|
||||
elif model == "gpt-image-1.5":
|
||||
price_extractor = calculate_tokens_price_image_1_5
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
model = "gpt-image-1"
|
||||
path = "/proxy/openai/images/generations"
|
||||
content_type = "application/json"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
files = []
|
||||
|
||||
if image is not None:
|
||||
files = []
|
||||
path = "/proxy/openai/images/edits"
|
||||
request_class = OpenAIImageEditRequest
|
||||
content_type = "multipart/form-data"
|
||||
|
||||
batch_size = image.shape[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
single_image = image[i: i + 1]
|
||||
scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()
|
||||
single_image = image[i : i + 1]
|
||||
scaled_image = downscale_image_tensor(single_image).squeeze()
|
||||
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
|
|
@ -456,59 +450,44 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||
else:
|
||||
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
|
||||
if mask is not None:
|
||||
if image.shape[0] != 1:
|
||||
raise Exception("Cannot use a mask with multiple image")
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
_, height, width = mask.shape
|
||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
if mask is not None:
|
||||
if image is None:
|
||||
raise Exception("Cannot use a mask without an input image")
|
||||
if image.shape[0] != 1:
|
||||
raise Exception("Cannot use a mask with multiple image")
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
batch, height, width = mask.shape
|
||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
|
||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()
|
||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
|
||||
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_img_byte_arr = BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_img_byte_arr = BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||
|
||||
# Build the operation
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=path, method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=request_class(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
seed=seed,
|
||||
size=size,
|
||||
),
|
||||
files=files if files else None,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=OpenAIImageEditRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
seed=seed,
|
||||
size=size,
|
||||
moderation="low",
|
||||
),
|
||||
content_type="multipart/form-data",
|
||||
files=files,
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
else:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=OpenAIImageGenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
seed=seed,
|
||||
size=size,
|
||||
moderation="low",
|
||||
),
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,568 @@
|
|||
"""
|
||||
Pika x ComfyUI API Nodes
|
||||
|
||||
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||
from comfy_api_nodes.apis import pika_api as pika_defs
|
||||
from comfy_api_nodes.util import (
|
||||
validate_string,
|
||||
download_url_to_video_output,
|
||||
tensor_to_bytesio,
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
|
||||
|
||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
|
||||
|
||||
PIKA_API_VERSION = "2.2"
|
||||
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
|
||||
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
|
||||
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
|
||||
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
||||
|
||||
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||
|
||||
|
||||
async def execute_task(
|
||||
task_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
) -> IO.NodeOutput:
|
||||
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
||||
response_model=pika_defs.PikaVideoResponse,
|
||||
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||
estimated_duration=60,
|
||||
max_poll_attempts=240,
|
||||
)
|
||||
if not final_response.url:
|
||||
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
video_url = final_response.url
|
||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||
return IO.NodeOutput(await download_url_to_video_output(video_url))
|
||||
|
||||
|
||||
def get_base_inputs_types() -> list[IO.Input]:
|
||||
"""Get the base required inputs types common to all Pika nodes."""
|
||||
return [
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
||||
IO.Combo.Input("duration", options=[5, 10], default=5),
|
||||
]
|
||||
|
||||
|
||||
class PikaImageToVideo(IO.ComfyNode):
|
||||
"""Pika 2.2 Image to Video Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaImageToVideoNode2_2",
|
||||
display_name="Pika Image to Video",
|
||||
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to convert to video"),
|
||||
*get_base_inputs_types(),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
) -> IO.NodeOutput:
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaTextToVideoNode(IO.ComfyNode):
|
||||
"""Pika Text2Video v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaTextToVideoNode2_2",
|
||||
display_name="Pika Text to Video",
|
||||
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
*get_base_inputs_types(),
|
||||
IO.Float.Input(
|
||||
"aspect_ratio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
max=2.5,
|
||||
default=1.7777777777777777,
|
||||
tooltip="Aspect ratio (width / height)",
|
||||
)
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
aspect_ratio: float,
|
||||
) -> IO.NodeOutput:
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
),
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaScenes(IO.ComfyNode):
|
||||
"""PikaScenes v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaScenesV2_2",
|
||||
display_name="Pika Scenes (Video Image Composition)",
|
||||
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
*get_base_inputs_types(),
|
||||
IO.Combo.Input(
|
||||
"ingredients_mode",
|
||||
options=["creative", "precise"],
|
||||
default="creative",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"aspect_ratio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
max=2.5,
|
||||
default=1.7777777777777777,
|
||||
tooltip="Aspect ratio (width / height)",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_1",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_2",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_3",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_4",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_5",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
ingredients_mode: str,
|
||||
aspect_ratio: float,
|
||||
image_ingredient_1: Optional[torch.Tensor] = None,
|
||||
image_ingredient_2: Optional[torch.Tensor] = None,
|
||||
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||
) -> IO.NodeOutput:
|
||||
all_image_bytes_io = []
|
||||
for image in [
|
||||
image_ingredient_1,
|
||||
image_ingredient_2,
|
||||
image_ingredient_3,
|
||||
image_ingredient_4,
|
||||
image_ingredient_5,
|
||||
]:
|
||||
if image is not None:
|
||||
all_image_bytes_io.append(tensor_to_bytesio(image))
|
||||
|
||||
pika_files = [
|
||||
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||
]
|
||||
|
||||
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||
ingredientsMode=ingredients_mode,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikAdditionsNode(IO.ComfyNode):
|
||||
"""Pika Pikadditions Node. Add an image into a video."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikadditions",
|
||||
display_name="Pikadditions (Video Object Insertion)",
|
||||
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="The video to add an image to."),
|
||||
IO.Image.Input("image", tooltip="The image to add to the video."),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
"image": ("image.png", image_bytes_io, "image/png"),
|
||||
}
|
||||
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaSwapsNode(IO.ComfyNode):
|
||||
"""Pika Pikaswaps Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikaswaps",
|
||||
display_name="Pika Swaps (Video Object Replacement)",
|
||||
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="The video to swap an object in."),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The image used to replace the masked object in the video.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Use the mask to define areas in the video to replace.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input("prompt_text", multiline=True, optional=True),
|
||||
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
||||
IO.String.Input(
|
||||
"region_to_modify",
|
||||
multiline=True,
|
||||
optional=True,
|
||||
tooltip="Plaintext description of the object / region to modify.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
prompt_text: str = "",
|
||||
negative_prompt: str = "",
|
||||
seed: int = 0,
|
||||
region_to_modify: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
}
|
||||
if mask is not None:
|
||||
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
|
||||
if image is not None:
|
||||
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
|
||||
|
||||
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaffectsNode(IO.ComfyNode):
|
||||
"""Pika Pikaffects Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikaffects",
|
||||
display_name="Pikaffects (Video Effects)",
|
||||
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||
IO.Combo.Input(
|
||||
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
|
||||
),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
pikaffect: str,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
pikaffect=pikaffect,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
),
|
||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
"""PikaFrames v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaStartEndFrameNode2_2",
|
||||
display_name="Pika Start and End Frame to Video",
|
||||
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Image.Input("image_start", tooltip="The first image to combine."),
|
||||
IO.Image.Input("image_end", tooltip="The last image to combine."),
|
||||
*get_base_inputs_types(),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image_start: torch.Tensor,
|
||||
image_end: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt_text, field_name="prompt_text", min_length=1)
|
||||
pika_files = [
|
||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
),
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaApiNodesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
PikaImageToVideo,
|
||||
PikaTextToVideoNode,
|
||||
PikaScenes,
|
||||
PikAdditionsNode,
|
||||
PikaSwapsNode,
|
||||
PikaffectsNode,
|
||||
PikaStartEndFrameNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> PikaApiNodesExtension:
|
||||
return PikaApiNodesExtension()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue