Compare commits

...

9 Commits

4 changed files with 82 additions and 3 deletions

View File

@ -4,6 +4,7 @@ import copy
import inspect
from abc import ABC, abstractmethod
from collections import Counter
from collections.abc import Iterable
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
@ -880,14 +881,24 @@ class ComboDynamic(ComfyTypeI):
@comfytype(io_type="COMFY_MATCHTYPE_V3")
class MatchType(ComfyTypeIO):
class Template:
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
self.template_id = template_id
self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
# account for syntactic sugar
if not isinstance(allowed_types, Iterable):
allowed_types = [allowed_types]
for t in allowed_types:
if not isinstance(t, type):
if not isinstance(t, _ComfyType):
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
else:
if not issubclass(t, _ComfyType):
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
self.allowed_types = allowed_types
def as_dict(self):
return {
"template_id": self.template_id,
"allowed_types": "".join(t.io_type for t in self.allowed_types),
"allowed_types": ",".join([t.io_type for t in self.allowed_types]),
}
class Input(DynamicInput):
@ -979,6 +990,7 @@ class NodeInfoV1:
output_is_list: list[bool]=None
output_name: list[str]=None
output_tooltips: list[str]=None
output_matchtypes: list[str]=None
name: str=None
display_name: str=None
description: str=None
@ -1118,12 +1130,24 @@ class Schema:
output_is_list = []
output_name = []
output_tooltips = []
output_matchtypes = []
any_matchtypes = False
if self.outputs:
for o in self.outputs:
output.append(o.io_type)
output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None)
# special handling for MatchType
if isinstance(o, MatchType.Output):
output_matchtypes.append(o.template.template_id)
any_matchtypes = True
else:
output_matchtypes.append(None)
# clear out lists that are all None
if not any_matchtypes:
output_matchtypes = None
info = NodeInfoV1(
input=input,
@ -1132,6 +1156,7 @@ class Schema:
output_is_list=output_is_list,
output_name=output_name,
output_tooltips=output_tooltips,
output_matchtypes=output_matchtypes,
name=self.node_id,
display_name=self.display_name,
category=self.category,
@ -1646,6 +1671,8 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
# Dynamic Types
"MatchType",
# Other classes
"HiddenHolder",
"Hidden",

View File

@ -1,4 +1,5 @@
from __future__ import annotations
from comfy_api.latest import IO
def validate_node_input(
@ -23,6 +24,11 @@ def validate_node_input(
if not received_type != input_type:
return True
# If the received type or input_type is a MatchType, we can return True immediately;
# validation for this is handled by the frontend
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
return True
# Not equal, and not strings
if not isinstance(received_type, str) or not isinstance(input_type, str):
return False

View File

@ -0,0 +1,45 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class SwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.MatchType.Template("switch")
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
category="logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
io.MatchType.Input("on_false", template=template, lazy=True),
io.MatchType.Input("on_true", template=template, lazy=True),
],
outputs=[
io.MatchType.Output("output", template=template, display_name="output"),
],
)
@classmethod
def check_lazy_status(cls, switch, on_false=None, on_true=None):
if switch and on_true is None:
return ["on_true"]
if not switch and on_false is None:
return ["on_false"]
@classmethod
def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
return io.NodeOutput(on_true if switch else on_false)
class LogicExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SwitchNode,
]
async def comfy_entrypoint() -> LogicExtension:
return LogicExtension()

View File

@ -2330,6 +2330,7 @@ async def init_builtin_extra_nodes():
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_rope.py",
"nodes_logic.py",
]
import_failed = []