Compare commits

...

91 Commits

Author SHA1 Message Date
Alexander Piskun 917177e821
move assets related stuff to "app/assets" folder (#10184) 2025-10-03 11:53:26 -07:00
Alexander Piskun fd6ac0a765
drop PgSQL 14, unite migration for SQLite and PgSQL (#10165) 2025-10-03 11:34:06 -07:00
Alexander Piskun 94941c50b3
move alembic_db inside app folder (#10163) 2025-10-02 15:01:16 -07:00
Jedrzej Kosinski fbba2e59e5 Satisfy ruff 2025-09-26 20:39:23 -07:00
Jedrzej Kosinski adccfb2dfd Remove populate_db_with_asset from load_torch_file for now, as nothing yet uses the hashes 2025-09-26 20:33:46 -07:00
Jedrzej Kosinski 9f4c0f3afe Merge branch 'master' into asset-management 2025-09-26 20:24:25 -07:00
Jedrzej Kosinski ca39552954 Merge branch 'master' into asset-management 2025-09-24 23:44:57 -07:00
Jedrzej Kosinski 4dd843d36f Merge branch 'master' into asset-management 2025-09-18 14:08:20 -07:00
Jedrzej Kosinski 46fdd636de
Merge pull request #9545 from bigcat88/asset-management
[Assets] Initial implementation
2025-09-18 14:07:18 -07:00
bigcat88 283cd27bdc
final adjustments 2025-09-18 10:05:32 +03:00
bigcat88 1a37d1476d
refactor(6): fully batched initial scan 2025-09-17 20:29:29 +03:00
bigcat88 f9602457d6
optimization: initial scan speed(batching metadata[filename]) 2025-09-17 16:47:27 +03:00
bigcat88 85ef08449d
optimization: initial scan speed(batching tags) 2025-09-17 14:08:57 +03:00
bigcat88 5b6810a2c6
fixed hash calculation during model loading in ComfyUI 2025-09-17 13:25:56 +03:00
bigcat88 621faaa195
refactor(5): use less DB queries to create seed asset 2025-09-17 10:46:21 +03:00
bigcat88 d0aa64d57b
refactor(4): use one query to init DB with all tags for assets 2025-09-16 21:18:18 +03:00
bigcat88 677a0e2508
refactor(3): unite logic for Asset fast check 2025-09-16 20:29:50 +03:00
bigcat88 31ec744317
refactor(2)/fix: skip double checking the existing files during fast check 2025-09-16 19:50:21 +03:00
bigcat88 a336c7c165
refactor(1): use general fast_asset_file_check helper for fast check 2025-09-16 19:19:18 +03:00
bigcat88 77332d3054
optimization: fast scan: commit to the DB in chunks 2025-09-16 14:21:40 +03:00
bigcat88 24a95f5ca4
removed default scanning of "input" and "output" folders; added separate endpoint for test suite. 2025-09-16 11:28:29 +03:00
bigcat88 0be513b213
fix: escape "_" symbol in all other places 2025-09-15 20:26:48 +03:00
bigcat88 f1fb7432a0
fix+test: escape "_" symbol in assets filtering 2025-09-15 19:19:47 +03:00
bigcat88 f3cf99d10c
fix+test: escape "_" symbol in tags filtering 2025-09-15 17:29:27 +03:00
bigcat88 5f187fe6fb
optimization: make list_unhashed_candidates_under_prefixes single-query instead of N+1 2025-09-15 12:46:35 +03:00
bigcat88 025fc49b4e
optimization: DB Queries (Tags) 2025-09-15 10:26:13 +03:00
bigcat88 7becb84341
fixed tests on SQLite file 2025-09-14 23:01:17 +03:00
bigcat88 dda31de690
rework: AssetInfo.name is only a display name 2025-09-14 21:53:44 +03:00
bigcat88 1d970382f0
added final tests 2025-09-14 20:02:28 +03:00
bigcat88 a2fc2bbae4
corrected formatting 2025-09-14 18:12:00 +03:00
bigcat88 a7f2546558
fix: use ".rowcount" instead of ".returning" on SQLite 2025-09-14 17:55:02 +03:00
bigcat88 6cfa94ec58
fixed metadata[filename] feature + new tests for this 2025-09-14 16:28:14 +03:00
bigcat88 a2ec1f7637
simplify code 2025-09-14 15:31:42 +03:00
bigcat88 0b795dc7a7
removed non-needed code 2025-09-14 15:14:24 +03:00
bigcat88 47f7c7ee8c
rework + add test for concurrent AssetInfo delete 2025-09-14 15:08:29 +03:00
bigcat88 cdd8d16075
+2 tests for checking Asset downloading logic 2025-09-14 14:57:24 +03:00
bigcat88 37b81e6658
fixed new PgSQL bug 2025-09-14 14:30:38 +03:00
bigcat88 975650060f
concurrency upload test + fixed 2 related bugs 2025-09-14 09:39:23 +03:00
bigcat88 4a713654cd
added more tests for the Assets logic 2025-09-14 09:10:59 +03:00
bigcat88 9b8e88ba6e
added more tests for the Assets logic 2025-09-13 20:09:45 +03:00
bigcat88 bb9ed04758
global refactoring; add support for Assets without the computed hash 2025-09-13 16:39:08 +03:00
Alexander Piskun 934377ac1e removed currently unnecessary "asset_locations" functionality 2025-09-12 14:46:22 +03:00
Alexander Piskun 3c9bf39c20
Merge pull request #1 from bigcat88/asset-management-ci
fix bugs + GH CI tests
2025-09-10 16:31:11 +03:00
bigcat88 0df1ccac6f
GitHub CI test for Assets 2025-09-10 16:22:22 +03:00
bigcat88 72548a8ac4
added additional tests; sorted tests 2025-09-10 10:39:55 +03:00
bigcat88 6eaed072c7
add some logic tests 2025-09-10 09:51:06 +03:00
bigcat88 a9096f6c97
removed non-needed code, fix tests, +1 new test 2025-09-09 20:54:11 +03:00
bigcat88 964de8a8ad
add more list_assets tests + fix one found bug 2025-09-09 20:35:18 +03:00
bigcat88 1886f10e19
add download tests 2025-09-09 19:30:58 +03:00
bigcat88 357193f7b5
fixed metadata filtering + tests 2025-09-09 19:12:11 +03:00
bigcat88 0ef73e95fd
fixed validation error + more tests 2025-09-09 16:02:39 +03:00
bigcat88 faa1e4de17
fixed another test 2025-09-09 15:17:03 +03:00
bigcat88 dfb5703d40
feat: remove Asset when there is no references left + bugfixes + more tests 2025-09-09 15:10:07 +03:00
bigcat88 0e9de2b7c9
feat: add first test 2025-09-08 20:43:45 +03:00
bigcat88 e3311c9229
feat: support for in-memory SQLite databases 2025-09-08 18:15:09 +03:00
bigcat88 3fa0fc496c
fix: use UPSERT to eliminate rare race condition during ingesting many small files in parallel 2025-09-08 18:13:32 +03:00
bigcat88 6282d495ca
corrected detection of missing files for assets 2025-09-07 22:08:38 +03:00
bigcat88 b8ef9bb92c
add detection of the missing files for existing assets 2025-09-07 16:49:39 +03:00
bigcat88 2d9be462d3
add support for assets duplicates 2025-09-06 19:22:51 +03:00
bigcat88 789a62ce35
assume that DB packages always present; refactoring & cleanup 2025-09-06 17:44:01 +03:00
bigcat88 84384ca0b4
temporary restore ModelManager 2025-09-05 23:02:26 +03:00
bigcat88 ce270ba090
added Assets Autoscan feature 2025-09-05 17:46:09 +03:00
bigcat88 bf8363ec87
always autofill "filename" in the metadata 2025-08-29 19:48:42 +03:00
bigcat88 6b86be320a
use UUID instead of autoincrement Integer for Assets ID field 2025-08-28 08:22:54 +03:00
bigcat88 bdf4ba24ce
removed not needed "assets.updated_at" column 2025-08-27 21:58:17 +03:00
bigcat88 871e41aec6
removed not needed "refcount" column 2025-08-27 21:36:31 +03:00
bigcat88 eb7008a4d3
removed not used "added_by" column 2025-08-27 21:26:35 +03:00
bigcat88 0379eff0b5
allow Upload Asset endpoint to accept hash (as documentation requires) 2025-08-27 21:18:26 +03:00
bigcat88 026b7f209c
add "--multi-user" support 2025-08-27 19:47:55 +03:00
bigcat88 7c1b0be496
add Get Asset endpoint 2025-08-27 09:58:12 +03:00
bigcat88 6fade5da38
add AssetsResolver support 2025-08-26 20:58:04 +03:00
bigcat88 a763cbd39d
add upload asset endpoint 2025-08-25 16:35:29 +03:00
bigcat88 09dabf95bc
refactoring: use the same code for "scan task" and realtime DB population 2025-08-25 13:31:56 +03:00
bigcat88 d7464e9e73
implemented assets scaner 2025-08-24 19:29:21 +03:00
bigcat88 a82577f64a
auto-creation of tags and fixed population DB when cloned asset is already present 2025-08-24 16:36:01 +03:00
bigcat88 f2ea0bc22c
added create_asset_from_hash endpoint 2025-08-24 14:15:21 +03:00
bigcat88 0755e5320a
remove timezone; download asset, delete asset endpoints 2025-08-24 12:36:20 +03:00
bigcat88 8d46bec951
use Pydantic for output; finished Tags endpoints 2025-08-24 11:02:30 +03:00
bigcat88 5c1b5973ac
dev: refactor; populate models in more nodes; use Pydantic in endpoints for input validation 2025-08-23 20:14:22 +03:00
bigcat88 f92307cd4c
dev: Everything is Assets 2025-08-23 19:21:52 +03:00
Jedrzej Kosinski c708d0a433 Merge branch 'master' into asset-management 2025-08-18 12:12:15 -07:00
Jedrzej Kosinski 1aa089e0b6 More progress on brainstorming code for asset management for models 2025-08-12 17:59:16 -07:00
Jedrzej Kosinski f032c1a50a Brainstorming abstraction for asset management stuff 2025-08-11 22:25:42 -07:00
Jedrzej Kosinski 3089936a2c Merge branch 'master' into pysssss-model-db 2025-08-11 14:09:21 -07:00
Jedrzej Kosinski cd679129e3 Merge branch 'master' into pysssss-model-db 2025-08-07 21:12:30 -07:00
pythongosssss d7062277a7 fix bad merge 2025-08-03 16:40:27 +01:00
pythongosssss 54cf14cbbb Merge remote-tracking branch 'origin/master' into pysssss-model-db 2025-08-03 16:36:49 +01:00
pythongosssss 7d5160f92c Tidy 2025-06-01 15:45:15 +01:00
pythongosssss 7f7b3f1695 tidy 2025-06-01 15:41:00 +01:00
pythongosssss 9da6aca0d0 Add additional db model metadata fields and model downloading function 2025-06-01 15:32:13 +01:00
pythongosssss 1cb3c98947 Implement database & model hashing 2025-06-01 15:32:02 +01:00
52 changed files with 8189 additions and 214 deletions

173
.github/workflows/test-assets.yml vendored Normal file

@@ -0,0 +1,173 @@
name: Asset System Tests
on:
push:
paths:
- 'app/**'
- 'tests-assets/**'
- '.github/workflows/test-assets.yml'
- 'requirements.txt'
pull_request:
branches: [master]
workflow_dispatch:
permissions:
contents: read
env:
PIP_DISABLE_PIP_VERSION_CHECK: '1'
PYTHONUNBUFFERED: '1'
jobs:
sqlite:
name: SQLite (${{ matrix.sqlite_mode }}) • Python ${{ matrix.python }}
runs-on: ubuntu-latest
timeout-minutes: 40
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.12']
sqlite_mode: ['memory', 'file']
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install dependencies
run: |
python -m pip install -U pip wheel
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install pytest pytest-aiohttp pytest-asyncio
- name: Set deterministic test base dir
id: basedir
shell: bash
run: |
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
mkdir -p "$BASE/logs"
echo "ASSETS_TEST_BASE_DIR=$BASE"
- name: Set DB URL for SQLite
id: setdb
shell: bash
run: |
if [ "${{ matrix.sqlite_mode }}" = "memory" ]; then
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///:memory:" >> "$GITHUB_ENV"
else
DBFILE="$RUNNER_TEMP/assets-tests.sqlite"
mkdir -p "$(dirname "$DBFILE")"
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///$DBFILE" >> "$GITHUB_ENV"
fi
- name: Run tests
run: python -m pytest tests-assets
- name: Show ComfyUI logs
if: always()
shell: bash
run: |
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
ls -la "$ASSETS_TEST_LOGS" || true
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
if [ -f "$f" ]; then
echo "----- BEGIN $f -----"
sed -n '1,400p' "$f"
echo "----- END $f -----"
fi
done
- name: Upload ComfyUI logs
if: always()
uses: actions/upload-artifact@v4
with:
name: asset-logs-sqlite-${{ matrix.sqlite_mode }}-py${{ matrix.python }}
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
if-no-files-found: warn
postgres:
name: PostgreSQL ${{ matrix.pgsql }} • Python ${{ matrix.python }}
runs-on: ubuntu-latest
timeout-minutes: 40
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.12']
pgsql: ['16', '18']
services:
postgres:
image: postgres:${{ matrix.pgsql }}
env:
POSTGRES_DB: assets
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U postgres -d assets"
--health-interval 10s
--health-timeout 5s
--health-retries 12
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install dependencies
run: |
python -m pip install -U pip wheel
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install pytest pytest-aiohttp pytest-asyncio
pip install greenlet psycopg
- name: Set deterministic test base dir
id: basedir
shell: bash
run: |
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
mkdir -p "$BASE/logs"
echo "ASSETS_TEST_BASE_DIR=$BASE"
- name: Set DB URL for PostgreSQL
shell: bash
run: |
echo "ASSETS_TEST_DB_URL=postgresql+psycopg://postgres:postgres@localhost:5432/assets" >> "$GITHUB_ENV"
- name: Run tests
run: python -m pytest tests-assets
- name: Show ComfyUI logs
if: always()
shell: bash
run: |
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
ls -la "$ASSETS_TEST_LOGS" || true
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
if [ -f "$f" ]; then
echo "----- BEGIN $f -----"
sed -n '1,400p' "$f"
echo "----- END $f -----"
fi
done
- name: Upload ComfyUI logs
if: always()
uses: actions/upload-artifact@v4
with:
name: asset-logs-pgsql-${{ matrix.pgsql }}-py${{ matrix.python }}
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
if-no-files-found: warn
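
The jobs above only prepare three environment variables (ASSETS_TEST_BASE_DIR, ASSETS_TEST_LOGS, ASSETS_TEST_DB_URL) before invoking pytest, so the suite can be reproduced locally without GitHub Actions. A minimal sketch, assuming the in-memory SQLite variant and a throwaway base directory:

# Local sketch mirroring the CI job's environment; only the variable names
# come from the workflow above, the directory layout is illustrative.
import os
import tempfile

import pytest

base = tempfile.mkdtemp(prefix="comfyui-assets-tests-")
logs = os.path.join(base, "logs")
os.makedirs(logs, exist_ok=True)

os.environ["ASSETS_TEST_BASE_DIR"] = base
os.environ["ASSETS_TEST_LOGS"] = logs
os.environ["ASSETS_TEST_DB_URL"] = "sqlite+aiosqlite:///:memory:"

raise SystemExit(pytest.main(["tests-assets"]))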

alembic.ini

@@ -3,7 +3,7 @@
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
-script_location = alembic_db
+script_location = app/alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time

app/alembic_db/env.py

@@ -2,13 +2,12 @@ from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from app.assets.database.models import Base
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,


@@ -0,0 +1,175 @@
"""initial assets schema
Revision ID: 0001_assets
Revises:
Create Date: 2025-08-20 00:00:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "0001_assets"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ASSETS: content identity
op.create_table(
"assets",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("hash", sa.String(length=256), nullable=True),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
# ASSETS_INFO: user-visible references
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
# TAGS: normalized tag vocabulary
op.create_table(
"tags",
sa.Column("name", sa.String(length=512), primary_key=True),
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
)
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON().with_variant(postgresql.JSONB(), 'postgresql'), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
# Tags vocabulary
tags_table = sa.table(
"tags",
sa.column("name", sa.String(length=512)),
sa.column("tag_type", sa.String()),
)
op.bulk_insert(
tags_table,
[
{"name": "models", "tag_type": "system"},
{"name": "input", "tag_type": "system"},
{"name": "output", "tag_type": "system"},
{"name": "configs", "tag_type": "system"},
{"name": "checkpoints", "tag_type": "system"},
{"name": "loras", "tag_type": "system"},
{"name": "vae", "tag_type": "system"},
{"name": "text_encoders", "tag_type": "system"},
{"name": "diffusion_models", "tag_type": "system"},
{"name": "clip_vision", "tag_type": "system"},
{"name": "style_models", "tag_type": "system"},
{"name": "embeddings", "tag_type": "system"},
{"name": "diffusers", "tag_type": "system"},
{"name": "vae_approx", "tag_type": "system"},
{"name": "controlnet", "tag_type": "system"},
{"name": "gligen", "tag_type": "system"},
{"name": "upscale_models", "tag_type": "system"},
{"name": "hypernetworks", "tag_type": "system"},
{"name": "photomaker", "tag_type": "system"},
{"name": "classifiers", "tag_type": "system"},
{"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"},
{"name": "missing", "tag_type": "system"},
{"name": "rescan", "tag_type": "system"},
],
)
def downgrade() -> None:
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_tags_tag_type", table_name="tags")
op.drop_table("tags")
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
op.drop_index("uq_assets_hash", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")

4
app/assets/__init__.py Normal file

@@ -0,0 +1,4 @@
from .api.routes import register_assets_system
from .scanner import sync_seed_assets
__all__ = ["sync_seed_assets", "register_assets_system"]
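
register_assets_system only needs the aiohttp application and a UserManager instance, so outside of ComfyUI's normal startup the wiring looks roughly like the sketch below (constructing UserManager with no arguments is an assumption here):

# Hedged sketch: attach the asset routes to a bare aiohttp application.
# ComfyUI's server does this during startup; UserManager() without
# arguments is assumed for illustration.
from aiohttp import web

from app import user_manager
from app.assets import register_assets_system

app = web.Application()
register_assets_system(app, user_manager.UserManager())
web.run_app(app, port=8188)  # port is illustrative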

225
app/assets/_helpers.py Normal file

@@ -0,0 +1,225 @@
import contextlib
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Optional, Sequence
import folder_paths
from .api import schemas_in
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> Optional[str]:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: schemas_in.RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def ts_to_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
try:
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
except Exception:
return None
def new_scan_id(root: schemas_in.RootType) -> str:
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
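
To make the path-to-tags mapping above concrete: a file under the configured models directory yields a display name plus ordered tags, and resolve_destination_from_tags maps the same tags back to a destination directory. A sketch whose concrete paths depend on the local folder_paths configuration:

# Sketch of the helper round-trip; absolute paths depend on the local
# ComfyUI installation, so the values in the comments are illustrative.
import os

import folder_paths
from app.assets._helpers import (
    get_name_and_tags_from_asset_path,
    resolve_destination_from_tags,
)

path = os.path.join(folder_paths.models_dir, "vae", "test_tag", "ae.safetensors")
name, tags = get_name_and_tags_from_asset_path(path)
# name == "ae.safetensors", tags == ["models", "vae", "test_tag"]

base_dir, subdirs = resolve_destination_from_tags(tags)
# base_dir -> first configured base path for the "vae" category, subdirs == ["test_tag"]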

544
app/assets/api/routes.py Normal file

@@ -0,0 +1,544 @@
import contextlib
import logging
import os
import urllib.parse
import uuid
from typing import Optional
from aiohttp import web
from pydantic import ValidationError
import folder_paths
from ... import user_manager
from .. import manager, scanner
from . import schemas_in, schemas_out
ROUTES = web.RouteTableDef()
USER_MANAGER: Optional[user_manager.UserManager] = None
LOGGER = logging.getLogger(__name__)
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = await manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
qp = request.rel_url.query
query_dict = {}
if "include_tags" in qp:
query_dict["include_tags"] = qp.getall("include_tags")
if "exclude_tags" in qp:
query_dict["exclude_tags"] = qp.getall("exclude_tags")
for k in ("name_contains", "metadata_filter", "limit", "offset", "sort", "order"):
v = qp.get(k)
if v is not None:
query_dict[k] = v
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = await manager.list_assets(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = await manager.resolve_asset_content_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
resp = web.FileResponse(abs_path)
resp.content_type = content_type
resp.headers["Content-Disposition"] = cd
return resp
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
result = await manager.create_asset_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
reader = await request.multipart()
file_present = False
file_client_name: Optional[str] = None
tags_raw: list[str] = []
provided_name: Optional[str] = None
user_metadata_raw: Optional[str] = None
provided_hash: Optional[str] = None
provided_hash_exists: Optional[bool] = None
file_written = 0
tmp_path: Optional[str] = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = await manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = await manager.create_asset_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
LOGGER.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = await manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
LOGGER.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
result = await manager.get_asset(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.update_asset(
asset_info_id=asset_info_id,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
async def set_asset_preview(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.SetPreviewBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.set_asset_preview(
asset_info_id=asset_info_id,
preview_asset_id=body.preview_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (PermissionError, ValueError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"set_asset_preview failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
try:
deleted = await manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
)
except Exception:
LOGGER.exception(
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return web.Response(status=204)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
query_map = dict(request.rel_url.query)
try:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as ve:
return web.json_response(
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}},
status=400,
)
result = await manager.list_tags(
prefix=query.prefix,
limit=query.limit,
offset=query.offset,
order=query.order,
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.add_tags_to_asset(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await manager.remove_tags_from_asset(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
LOGGER.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/scan/seed")
async def seed_assets(request: web.Request) -> web.Response:
try:
payload = await request.json()
except Exception:
payload = {}
try:
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
try:
await scanner.sync_seed_assets(body.roots)
except Exception:
LOGGER.exception("sync_seed_assets failed for roots=%s", body.roots)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response({"synced": True, "roots": body.roots}, status=200)
@ROUTES.post("/api/assets/scan/schedule")
async def schedule_asset_scan(request: web.Request) -> web.Response:
try:
payload = await request.json()
except Exception:
payload = {}
try:
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
states = await scanner.schedule_scans(body.roots)
return web.json_response(states.model_dump(mode="json"), status=202)
@ROUTES.get("/api/assets/scan")
async def get_asset_scan_status(request: web.Request) -> web.Response:
root = request.query.get("root", "").strip().lower()
states = scanner.current_statuses()
if root in {"models", "input", "output"}:
states = [s for s in states.scans if s.root == root] # type: ignore
states = schemas_out.AssetScanStatusResponse(scans=states)
return web.json_response(states.model_dump(mode="json"), status=200)
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
app.add_routes(ROUTES)
def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
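
The multipart field names accepted by upload_asset ("file", "tags", "name", "user_metadata", "hash") are exactly the ones read in the handler above. A client-side sketch against a locally running server; the host, port and file name are assumptions, and requests is used purely for illustration:

# Client-side sketch of POST /api/assets (multipart upload).
import json

import requests

with open("ae.safetensors", "rb") as fh:
    resp = requests.post(
        "http://127.0.0.1:8188/api/assets",
        files={"file": ("ae.safetensors", fh)},
        data={
            "tags": "models,vae",  # first tag is the root, second the model category
            "name": "ae.safetensors",
            "user_metadata": json.dumps({"source": "manual upload"}),
        },
    )
print(resp.status_code, resp.json())  # 201 on first upload, 200 if the content already existed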

297
app/assets/api/schemas_in.py Normal file

@@ -0,0 +1,297 @@
import json
import uuid
from typing import Any, Literal, Optional
from pydantic import (
BaseModel,
ConfigDict,
Field,
conint,
field_validator,
model_validator,
)
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: Optional[str] = None
# Accept either a JSON string (query param) or a dict
metadata_filter: Optional[dict[str, Any]] = None
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class UpdateAssetBody(BaseModel):
name: Optional[str] = None
tags: Optional[list[str]] = None
user_metadata: Optional[dict[str, Any]] = None
@model_validator(mode="after")
def _at_least_one(self):
if self.name is None and self.tags is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, tags, user_metadata.")
if self.tags is not None:
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
raise ValueError("Field 'tags' must be an array of strings.")
return self
class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
@field_validator("hash")
@classmethod
def _require_blake3(cls, v):
s = (v or "").strip().lower()
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
if isinstance(v, str):
return [t.strip().lower() for t in v.split(",") if t.strip()]
return []
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
prefix: Optional[str] = Field(None, min_length=1, max_length=256)
limit: int = Field(100, ge=1, le=1000)
offset: int = Field(0, ge=0, le=10_000_000)
order: Literal["count_desc", "name_asc"] = "count_desc"
include_zero: bool = True
@field_validator("prefix")
@classmethod
def normalize_prefix(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
v = v.strip()
return v.lower() or None
class TagsAdd(BaseModel):
model_config = ConfigDict(extra="ignore")
tags: list[str] = Field(..., min_length=1)
@field_validator("tags")
@classmethod
def normalize_tags(cls, v: list[str]) -> list[str]:
out = []
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip().lower()
if tnorm:
out.append(tnorm)
seen = set()
deduplicated = []
for x in out:
if x not in seen:
seen.add(x)
deduplicated.append(x)
return deduplicated
class TagsRemove(TagsAdd):
pass
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
class ScheduleAssetScanBody(BaseModel):
roots: list[RootType] = Field(..., min_length=1)
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
name: Optional[str] = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: Optional[str] = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
def _parse_hash(cls, v):
if v is None:
return None
s = str(v).strip().lower()
if not s:
return None
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
"""
Accepts a list of strings (possibly multiple form fields),
where each string can be:
- JSON array (e.g., '["models","loras","foo"]')
- comma-separated ('models, loras, foo')
- single token ('models')
Returns a normalized, deduplicated, ordered list.
"""
items: list[str] = []
if v is None:
return []
if isinstance(v, str):
v = [v]
if isinstance(v, list):
for item in v:
if item is None:
continue
s = str(item).strip()
if not s:
continue
if s.startswith("["):
try:
arr = json.loads(s)
if isinstance(arr, list):
items.extend(str(x) for x in arr)
continue
except Exception:
pass # fallback to CSV parse below
items.extend([p for p in s.split(",") if p.strip()])
else:
return []
# normalize + dedupe
norm = []
seen = set()
for t in items:
tnorm = str(t).strip().lower()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
return norm
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v or {}
if isinstance(v, str):
s = v.strip()
if not s:
return {}
try:
parsed = json.loads(s)
except Exception as e:
raise ValueError(f"user_metadata must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("user_metadata must be a JSON object")
return parsed
return {}
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
return self
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: Optional[str] = None
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s
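
The validators above are deliberately liberal about input shape (JSON arrays, CSV strings, repeated form fields). A short sketch of what UploadAssetSpec and ListAssetsQuery actually normalize; the import path follows the relative imports used in routes.py:

# Sketch of validator behaviour for the two most permissive input models.
from app.assets.api.schemas_in import ListAssetsQuery, UploadAssetSpec

spec = UploadAssetSpec.model_validate({
    "tags": ['["models","loras"]', "Flux, flux"],  # JSON array + CSV, mixed case
    "user_metadata": '{"source": "test"}',
})
print(spec.tags)           # ['models', 'loras', 'flux']  (lowercased, deduplicated, ordered)
print(spec.user_metadata)  # {'source': 'test'}

q = ListAssetsQuery.model_validate({"include_tags": "models,loras", "limit": "50"})
print(q.include_tags, q.limit, q.sort)  # ['models', 'loras'] 50 created_at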

115
app/assets/api/schemas_out.py Normal file

@@ -0,0 +1,115 @@
from datetime import datetime
from typing import Any, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: Optional[str]
size: Optional[int] = None
mime_type: Optional[str] = None
tags: list[str] = Field(default_factory=list)
preview_url: Optional[str] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
last_access_time: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetsList(BaseModel):
assets: list[AssetSummary]
total: int
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: Optional[str]
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: Optional[str]
size: Optional[int] = None
mime_type: Optional[str] = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: Optional[str] = None
created_at: Optional[datetime] = None
last_access_time: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: Optional[datetime], _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
type: str
class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool
class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class AssetScanError(BaseModel):
path: str
message: str
at: Optional[str] = Field(None, description="ISO timestamp")
class AssetScanStatus(BaseModel):
scan_id: str
root: Literal["models", "input", "output"]
status: Literal["scheduled", "running", "completed", "failed", "cancelled"]
scheduled_at: Optional[str] = None
started_at: Optional[str] = None
finished_at: Optional[str] = None
discovered: int = 0
processed: int = 0
file_errors: list[AssetScanError] = Field(default_factory=list)
class AssetScanStatusResponse(BaseModel):
scans: list[AssetScanStatus] = Field(default_factory=list)
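# Example (illustrative; all values are hypothetical): constructing one of the response
# models above and dumping it -- the field serializers emit ISO-8601 strings for datetimes.
#   AssetSummary(
#       id="6f1c...", name="foo.safetensors", asset_hash="blake3:ab12...",
#       size=1024, tags=["models", "loras"],
#       created_at=datetime(2025, 9, 18, 10, 0, 0),
#   ).model_dump()
#   -> {..., "created_at": "2025-09-18T10:00:00", "last_access_time": None, ...}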

View File

View File

@ -0,0 +1,25 @@
from .bulk_ops import seed_from_paths_batch
from .escape_like import escape_like_prefix
from .fast_check import fast_asset_file_check
from .filters import apply_metadata_filter, apply_tag_filters
from .ownership import visible_owner_clause
from .projection import is_scalar, project_kv
from .tags import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
remove_missing_tag_for_asset_id,
)
__all__ = [
"apply_tag_filters",
"apply_metadata_filter",
"escape_like_prefix",
"fast_asset_file_check",
"is_scalar",
"project_kv",
"ensure_tags_exist",
"add_missing_tag_for_asset_id",
"remove_missing_tag_for_asset_id",
"seed_from_paths_batch",
"visible_owner_clause",
]

View File

@ -0,0 +1,230 @@
import os
import uuid
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
from ..timeutil import utcnow
MAX_BIND_PARAMS = 800
async def seed_from_paths_batch(
session: AsyncSession,
*,
specs: Sequence[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
dialect = session.bind.dialect.name
if dialect not in ("sqlite", "postgresql"):
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
await session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
winners_by_path: set[str] = set()
if dialect == "sqlite":
ins_state = (
d_sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
else:
ins_state = (
d_pg.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
if dialect == "sqlite":
ins_info = (
d_sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
else:
ins_info = (
d_pg.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
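# Example (illustrative; every value is hypothetical): one entry of the `specs` argument
# accepted by seed_from_paths_batch above.
#   {
#       "abs_path": "/ComfyUI/models/loras/foo.safetensors",
#       "size_bytes": 12_345_678,
#       "mtime_ns": 1_726_500_000_000_000_000,
#       "info_name": "foo.safetensors",
#       "tags": ["models", "loras"],
#       "fname": "loras/foo.safetensors",
#   }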
async def bulk_insert_tags_and_meta(
session: AsyncSession,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
dialect = session.bind.dialect.name
if tag_rows:
if dialect == "sqlite":
ins_links = (
d_sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif dialect == "postgresql":
ins_links = (
d_pg.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
await session.execute(ins_links, chunk)
if meta_rows:
if dialect == "sqlite":
ins_meta = (
d_sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
elif dialect == "postgresql":
ins_meta = (
d_pg.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
await session.execute(ins_meta, chunk)
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
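# Example (illustrative arithmetic): how the chunking above bounds statement size.
# With MAX_BIND_PARAMS = 800 and 5 columns per seed Asset row, _rows_per_stmt(5) == 160,
# so a batch of 1_000 specs becomes 7 INSERT statements (six of 160 rows and one of 40);
# the 3-column cache-state rows go 266 per statement.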

View File

@ -0,0 +1,7 @@
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape
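# Example (illustrative; `base` and `stmt` are hypothetical): escaping a path prefix
# before a LIKE query, as the DB helpers in this package do.
#   escape_like_prefix("models/sd_1.5")  -> ("models/sd!_1.5", "!")
#   prefix, esc = escape_like_prefix(base)
#   stmt = stmt.where(AssetCacheState.file_path.like(prefix + "%", escape=esc))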

View File

@ -0,0 +1,19 @@
import os
from typing import Optional
def fast_asset_file_check(
*,
mtime_db: Optional[int],
size_db: Optional[int],
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
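# Example (illustrative; `row` is a hypothetical record holding the cached mtime_ns and
# size_bytes): a typical call site during a fast scan.
#   st = os.stat(path, follow_symlinks=True)
#   if not fast_asset_file_check(mtime_db=row.mtime_ns, size_db=row.size_bytes, stat_result=st):
#       ...  # file changed or unknown: schedule it for re-hashing / verification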

View File

@ -0,0 +1,87 @@
from typing import Optional, Sequence
import sqlalchemy as sa
from sqlalchemy import exists
from ..._helpers import normalize_tags
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Optional[Sequence[str]],
exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: Optional[dict],
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
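# Example (illustrative; keys and values are hypothetical): what a metadata_filter
# argument can look like and how apply_metadata_filter above interprets it.
#   {"filename": "loras/foo.safetensors"}  -> EXISTS row with val_str == "loras/foo.safetensors"
#   {"epoch": 10}                          -> EXISTS row with val_num == Decimal("10")
#   {"nsfw": None}                         -> key absent OR a row with all value columns NULL
#   {"base_model": ["sd15", "sdxl"]}       -> OR of the per-element EXISTS clauses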

View File

@ -0,0 +1,12 @@
import sqlalchemy as sa
from ..models import AssetInfo
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
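# Example (illustrative; the owner id is hypothetical): predicates produced above.
#   visible_owner_clause("")        -> AssetInfo.owner_id == ""
#   visible_owner_clause("user-42") -> AssetInfo.owner_id IN ("", "user-42")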

View File

@ -0,0 +1,64 @@
from decimal import Decimal
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
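# Example (illustrative; keys and values are hypothetical): projection rows produced by
# project_kv for a few typical values.
#   project_kv("epoch", 10)             -> [{"key": "epoch", "ordinal": 0, "val_num": Decimal("10")}]
#   project_kv("triggers", ["a", "b"])  -> [{"key": "triggers", "ordinal": 0, "val_str": "a"},
#                                           {"key": "triggers", "ordinal": 1, "val_str": "b"}]
#   project_kv("cfg", {"steps": 20})    -> [{"key": "cfg", "ordinal": 0, "val_json": {"steps": 20}}]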

View File

@ -0,0 +1,90 @@
from typing import Iterable
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession
from ..._helpers import normalize_tags
from ..models import AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
dialect = session.bind.dialect.name
if dialect == "sqlite":
ins = (
d_sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
await session.execute(ins)
async def add_missing_tag_for_asset_id(
session: AsyncSession,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetInfo.id.label("asset_info_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
dialect = session.bind.dialect.name
if dialect == "sqlite":
ins = (
d_sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
await session.execute(ins)
async def remove_missing_tag_for_asset_id(
session: AsyncSession,
*,
asset_id: str,
) -> None:
await session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
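# Example (illustrative; a possible pairing of the two helpers above, not a prescribed
# flow): tag every AssetInfo of an asset whose file disappeared, and clear the tag once
# one of its paths is seen and hashed again.
#   await add_missing_tag_for_asset_id(session, asset_id=asset.id)      # path vanished
#   await remove_missing_tag_for_asset_id(session, asset_id=asset.id)   # path rediscovered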

View File

@ -0,0 +1,251 @@
import uuid
from datetime import datetime
from typing import Any, Optional
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
CheckConstraint,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship
from .timeutil import utcnow
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
class Base(DeclarativeBase):
pass
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys()
out: dict[str, Any] = {}
for field in fields:
val = getattr(obj, field)
if val is None and not include_none:
continue
if isinstance(val, datetime):
out[field] = val.isoformat()
else:
out[field] = val
return out
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
infos: Mapped[list["AssetInfo"]] = relationship(
"AssetInfo",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
preview_of: Mapped[list["AssetInfo"]] = relationship(
"AssetInfo",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
viewonly=True,
)
cache_states: Mapped[list["AssetCacheState"]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
Index("ix_assets_mime_type", "mime_type"),
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
foreign_keys=[asset_id],
lazy="selectin",
)
preview_asset: Mapped[Optional[Asset]] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
)
tags: Mapped[list["Tag"]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Optional[Any]] = mapped_column(JSONB_V, nullable=True)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
)
class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
)
asset_infos: Mapped[list["AssetInfo"]] = relationship(
secondary="asset_info_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@ -0,0 +1,57 @@
from .content import (
check_fs_asset_exists_quick,
compute_hash_and_dedup_for_cache_state,
ingest_fs_asset,
list_cache_states_with_asset_under_prefixes,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
redirect_all_references_then_delete_asset,
touch_asset_infos_by_fs_path,
)
from .info import (
add_tags_to_asset_info,
create_asset_info_for_existing_asset,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_tags,
list_asset_infos_page,
list_tags_with_usage,
remove_tags_from_asset_info,
replace_asset_info_metadata_projection,
set_asset_info_preview,
set_asset_info_tags,
touch_asset_info_by_id,
update_asset_info_full,
)
from .queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
get_cache_state_by_asset_id,
list_cache_states_by_asset_id,
pick_best_live_path,
)
__all__ = [
# queries
"asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
"get_cache_state_by_asset_id",
"list_cache_states_by_asset_id",
"pick_best_live_path",
# info
"list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
"update_asset_info_full", "replace_asset_info_metadata_projection",
"touch_asset_info_by_id", "delete_asset_info_by_id",
"add_tags_to_asset_info", "remove_tags_from_asset_info",
"get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
"fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
# content
"check_fs_asset_exists_quick",
"redirect_all_references_then_delete_asset",
"compute_hash_and_dedup_for_cache_state",
"list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",
"ingest_fs_asset", "touch_asset_infos_by_fs_path",
"list_cache_states_with_asset_under_prefixes",
]

View File

@ -0,0 +1,721 @@
import contextlib
import logging
import os
from datetime import datetime
from typing import Any, Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from ..._helpers import compute_relative_filename
from ...storage import hashing as hashing_mod
from ..helpers import (
ensure_tags_exist,
escape_like_prefix,
remove_missing_tag_for_asset_id,
)
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow
from .info import replace_asset_info_metadata_projection
from .queries import list_cache_states_by_asset_id, pick_best_live_path
async def check_fs_asset_exists_quick(
session: AsyncSession,
*,
file_path: str,
size_bytes: Optional[int] = None,
mtime_ns: Optional[int] = None,
) -> bool:
"""Returns True if we already track this absolute path with a HASHED asset and the cached mtime/size match."""
locator = os.path.abspath(file_path)
stmt = (
sa.select(sa.literal(True))
.select_from(AssetCacheState)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(
AssetCacheState.file_path == locator,
Asset.hash.isnot(None),
AssetCacheState.needs_verify.is_(False),
)
.limit(1)
)
conds = []
if mtime_ns is not None:
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
if size_bytes is not None:
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
if conds:
stmt = stmt.where(*conds)
return (await session.execute(stmt)).first() is not None
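# Example (illustrative; `path` is hypothetical): the quick check above lets a scan skip
# files that are already tracked and hashed.
#   st = os.stat(path, follow_symlinks=True)
#   if await check_fs_asset_exists_quick(
#       session, file_path=path, size_bytes=st.st_size, mtime_ns=st.st_mtime_ns,
#   ):
#       continue  # unchanged since last scan; no re-hash needed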
async def redirect_all_references_then_delete_asset(
session: AsyncSession,
*,
duplicate_asset_id: str,
canonical_asset_id: str,
) -> None:
"""
Safely migrate all references from duplicate_asset_id to canonical_asset_id.
- If an AssetInfo for (owner_id, name) already exists on the canonical asset,
merge tags, metadata, times, and preview, then delete the duplicate AssetInfo.
- Otherwise, simply repoint the AssetInfo.asset_id.
- Always retarget AssetCacheState rows.
- Finally delete the duplicate Asset row.
"""
if duplicate_asset_id == canonical_asset_id:
return
# 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts.
dup_infos = (
await session.execute(
select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id)
)
).unique().scalars().all()
for info in dup_infos:
# Try to find an existing collision on canonical
existing = (
await session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == canonical_asset_id,
AssetInfo.owner_id == info.owner_id,
AssetInfo.name == info.name,
)
.limit(1)
)
).unique().scalars().first()
if existing:
merged_meta = dict(existing.user_metadata or {})
other_meta = info.user_metadata or {}
for k, v in other_meta.items():
if k not in merged_meta:
merged_meta[k] = v
if merged_meta != (existing.user_metadata or {}):
await replace_asset_info_metadata_projection(
session,
asset_info_id=existing.id,
user_metadata=merged_meta,
)
existing_tags = {
t for (t,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id)
)
).all()
}
from_tags = {
t for (t,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id)
)
).all()
}
to_add = sorted(from_tags - existing_tags)
if to_add:
await ensure_tags_exist(session, to_add, tag_type="user")
now = utcnow()
session.add_all([
AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now)
for t in to_add
])
await session.flush()
if existing.preview_id is None and info.preview_id is not None:
existing.preview_id = info.preview_id
if info.last_access_time and (
existing.last_access_time is None or info.last_access_time > existing.last_access_time
):
existing.last_access_time = info.last_access_time
existing.updated_at = utcnow()
await session.flush()
# Delete the duplicate AssetInfo (cascades will clean its tags/meta)
await session.delete(info)
await session.flush()
else:
# Simple retarget
info.asset_id = canonical_asset_id
info.updated_at = utcnow()
await session.flush()
# 2) Repoint cache states and previews
await session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.asset_id == duplicate_asset_id)
.values(asset_id=canonical_asset_id)
)
await session.execute(
sa.update(AssetInfo)
.where(AssetInfo.preview_id == duplicate_asset_id)
.values(preview_id=canonical_asset_id)
)
# 3) Remove duplicate Asset
dup = await session.get(Asset, duplicate_asset_id)
if dup:
await session.delete(dup)
await session.flush()
async def compute_hash_and_dedup_for_cache_state(
session: AsyncSession,
*,
state_id: int,
) -> Optional[str]:
"""
Compute hash for the given cache state, deduplicate, and settle verify cases.
Returns the asset_id that this state ends up pointing to, or None if file disappeared.
"""
state = await session.get(AssetCacheState, state_id)
if not state:
return None
path = state.file_path
try:
if not os.path.isfile(path):
# File vanished: drop the state. If the Asset has hash=NULL and has no other states, drop the Asset too.
asset = await session.get(Asset, state.asset_id)
await session.delete(state)
await session.flush()
if asset and asset.hash is None:
remaining = (
await session.execute(
sa.select(sa.func.count())
.select_from(AssetCacheState)
.where(AssetCacheState.asset_id == asset.id)
)
).scalar_one()
if int(remaining or 0) == 0:
await session.delete(asset)
await session.flush()
else:
await _recompute_and_apply_filename_for_asset(session, asset_id=asset.id)
return None
digest = await hashing_mod.blake3_hash(path)
new_hash = f"blake3:{digest}"
st = os.stat(path, follow_symlinks=True)
new_size = int(st.st_size)
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
# Current asset of this state
this_asset = await session.get(Asset, state.asset_id)
# If the state got orphaned somehow (race), just reattach appropriately.
if not this_asset:
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical:
state.asset_id = canonical.id
else:
now = utcnow()
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
session.add(new_asset)
await session.flush()
state.asset_id = new_asset.id
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)
await session.flush()
return state.asset_id
# 1) Seed asset case (hash is NULL): claim or merge into canonical
if this_asset.hash is None:
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical and canonical.id != this_asset.id:
# Merge seed asset into canonical (safe, collision-aware)
await redirect_all_references_then_delete_asset(
session,
duplicate_asset_id=this_asset.id,
canonical_asset_id=canonical.id,
)
state = await session.get(AssetCacheState, state_id)
if state:
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
await session.flush()
return canonical.id
# No canonical: try to claim the hash; handle races with a SAVEPOINT
try:
async with session.begin_nested():
this_asset.hash = new_hash
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
this_asset.size_bytes = new_size
await session.flush()
except IntegrityError:
# Someone else claimed it concurrently; fetch canonical and merge
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical and canonical.id != this_asset.id:
await redirect_all_references_then_delete_asset(
session,
duplicate_asset_id=this_asset.id,
canonical_asset_id=canonical.id,
)
state = await session.get(AssetCacheState, state_id)
if state:
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
await session.flush()
return canonical.id
# If we got here, the integrity error was not about hash uniqueness
raise
# Claimed successfully
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
await session.flush()
return this_asset.id
# 2) Verify case for hashed assets
if this_asset.hash == new_hash:
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
this_asset.size_bytes = new_size
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
await session.flush()
return this_asset.id
# Content changed on this path only: retarget THIS state, do not move AssetInfo rows
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical:
target_id = canonical.id
else:
now = utcnow()
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
session.add(new_asset)
await session.flush()
target_id = new_asset.id
state.asset_id = target_id
state.mtime_ns = mtime_ns
state.needs_verify = False
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(session, asset_id=target_id)
await _recompute_and_apply_filename_for_asset(session, asset_id=target_id)
await session.flush()
return target_id
except Exception:
raise
async def list_unhashed_candidates_under_prefixes(session: AsyncSession, *, prefixes: list[str]) -> list[int]:
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
path_filter = sa.or_(*conds) if len(conds) > 1 else conds[0]
if session.bind.dialect.name == "postgresql":
stmt = (
sa.select(AssetCacheState.id)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(Asset.hash.is_(None), path_filter)
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
.distinct(AssetCacheState.asset_id)
)
else:
first_id = sa.func.min(AssetCacheState.id).label("first_id")
stmt = (
sa.select(first_id)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(Asset.hash.is_(None), path_filter)
.group_by(AssetCacheState.asset_id)
.order_by(first_id.asc())
)
return [int(x) for x in (await session.execute(stmt)).scalars().all()]
async def list_verify_candidates_under_prefixes(
session: AsyncSession, *, prefixes: Sequence[str]
) -> Union[list[int], Sequence[int]]:
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
return (
await session.execute(
sa.select(AssetCacheState.id)
.where(AssetCacheState.needs_verify.is_(True))
.where(sa.or_(*conds))
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
async def ingest_fs_asset(
session: AsyncSession,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: Optional[str] = None,
info_name: Optional[str] = None,
owner_id: str = "",
preview_id: Optional[str] = None,
user_metadata: Optional[dict] = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
dialect = session.bind.dialect.name
if preview_id:
if not await session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
if dialect == "sqlite":
res = await session.execute(
d_sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
await session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
elif dialect == "postgresql":
res = await session.execute(
d_pg.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(
index_elements=[Asset.hash],
index_where=Asset.__table__.c.hash.isnot(None),
)
.returning(Asset.id)
)
inserted_id = res.scalar_one_or_none()
if inserted_id:
out["asset_created"] = True
asset = await session.get(Asset, inserted_id)
else:
asset = (
await session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
if dialect == "sqlite":
ins = (
d_sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
res = await session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = await session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
async with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
await session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
await session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
await session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
await ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
await session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
await replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
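# Example (illustrative; every value is hypothetical): a possible ingest call after
# hashing a file on disk.
#   st = os.stat(path)
#   result = await ingest_fs_asset(
#       session,
#       asset_hash="blake3:" + digest,
#       abs_path=path, size_bytes=st.st_size, mtime_ns=st.st_mtime_ns,
#       info_name=os.path.basename(path),
#       tags=["models", "loras"], tag_origin="automatic",
#   )
#   # for a brand-new file one would expect something like:
#   # {"asset_created": True, "state_created": True, "asset_info_id": "<uuid>", ...}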
async def touch_asset_infos_by_fs_path(
session: AsyncSession,
*,
file_path: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> None:
locator = os.path.abspath(file_path)
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(
sa.exists(
sa.select(sa.literal(1))
.select_from(AssetCacheState)
.where(
AssetCacheState.asset_id == AssetInfo.asset_id,
AssetCacheState.file_path == locator,
)
)
)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None),
AssetInfo.last_access_time < ts,
)
)
await session.execute(stmt.values(last_access_time=ts))
async def list_cache_states_with_asset_under_prefixes(
session: AsyncSession,
*,
prefixes: Sequence[str],
) -> list[tuple[AssetCacheState, Optional[str], int]]:
"""Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix."""
if not prefixes:
return []
conds = []
for p in prefixes:
if not p:
continue
base = os.path.abspath(p)
if not base.endswith(os.sep):
base = base + os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
if not conds:
return []
rows = (
await session.execute(
select(AssetCacheState, Asset.hash, Asset.size_bytes)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.id.asc())
)
).all()
return [(r[0], r[1], int(r[2] or 0)) for r in rows]
async def _recompute_and_apply_filename_for_asset(session: AsyncSession, *, asset_id: str) -> None:
"""Compute filename from the first *existing* cache state path and apply it to all AssetInfo (if changed)."""
try:
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset_id))
if not primary_path:
return
new_filename = compute_relative_filename(primary_path)
if not new_filename:
return
infos = (
await session.execute(select(AssetInfo).where(AssetInfo.asset_id == asset_id))
).scalars().all()
for info in infos:
current_meta = info.user_metadata or {}
if current_meta.get("filename") == new_filename:
continue
updated = dict(current_meta)
updated["filename"] = new_filename
await replace_asset_info_metadata_projection(session, asset_info_id=info.id, user_metadata=updated)
except Exception:
logging.exception("Failed to recompute filename metadata for asset %s", asset_id)

View File

@ -0,0 +1,586 @@
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, noload
from ..._helpers import compute_relative_filename, normalize_tags
from ..helpers import (
apply_metadata_filter,
apply_tag_filters,
ensure_tags_exist,
escape_like_prefix,
project_kv,
visible_owner_clause,
)
from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from ..timeutil import utcnow
from .queries import (
get_asset_by_hash,
list_cache_states_by_asset_id,
pick_best_live_path,
)
async def list_asset_infos_page(
session: AsyncSession,
*,
owner_id: str = "",
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((await session.execute(count_stmt)).scalar_one() or 0)
infos = (await session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = await session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
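# Example (illustrative; argument values are hypothetical): a paged listing call.
#   infos, tag_map, total = await list_asset_infos_page(
#       session, owner_id="", include_tags=["models"], name_contains="lora",
#       limit=20, offset=0, sort="created_at", order="desc",
#   )
#   # `infos` are AssetInfo rows, `tag_map` maps info.id -> [tag names], `total` ignores paging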
async def fetch_asset_info_and_asset(
session: AsyncSession,
*,
asset_info_id: str,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset]]:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = await session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
async def fetch_asset_info_asset_and_tags(
session: AsyncSession,
*,
asset_info_id: str,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (await session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
async def create_asset_info_for_existing_asset(
session: AsyncSession,
*,
asset_hash: str,
name: str,
user_metadata: Optional[dict] = None,
tags: Optional[Sequence[str]] = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = await get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
async with session.begin_nested():
session.add(info)
await session.flush()
except IntegrityError:
existing = (
await session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
await replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
async def set_asset_info_tags(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
await ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
await session.flush()
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
await session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
async def update_asset_info_full(
session: AsyncSession,
*,
asset_info_id: str,
name: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
user_metadata: Optional[dict] = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
await session.flush()
return info
async def replace_asset_info_metadata_projection(
session: AsyncSession,
*,
asset_info_id: str,
user_metadata: Optional[dict],
) -> None:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
await session.flush()
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
await session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
await session.flush()
async def touch_asset_info_by_id(
session: AsyncSession,
*,
asset_info_id: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
await session.execute(stmt.values(last_access_time=ts))
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((await session.execute(stmt)).rowcount or 0) > 0
async def add_tags_to_asset_info(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
await ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
async with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
await session.flush()
except IntegrityError:
await nested.rollback()
after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
async def remove_tags_from_asset_info(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
await session.flush()
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
async def list_tags_with_usage(
session: AsyncSession,
*,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (await session.execute(q.limit(limit).offset(offset))).all()
total = (await session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
async def set_asset_info_preview(
session: AsyncSession,
*,
asset_info_id: str,
preview_asset_id: Optional[str],
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not await session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
await session.flush()
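A minimal usage sketch for the tag helpers above, assuming they are importable from the services module added in this PR (manager.py later in this diff imports them from .database.services); the asset_info_id and tag names are placeholders.
from app.db import create_session
from app.assets.database.services import add_tags_to_asset_info, remove_tags_from_asset_info

async def retag(asset_info_id: str) -> None:
    async with await create_session() as session:
        # Each helper returns a summary dict, e.g. {"added": ["favorite"], "already_present": [], "total_tags": [...]}.
        added = await add_tags_to_asset_info(session, asset_info_id=asset_info_id, tags=["favorite"])
        removed = await remove_tags_from_asset_info(session, asset_info_id=asset_info_id, tags=["missing"])
        await session.commit()
        print(added["added"], removed["removed"])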


@ -0,0 +1,76 @@
import os
from typing import Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Asset, AssetCacheState, AssetInfo
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
row = (
await session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
return (
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]:
return await session.get(AssetInfo, asset_info_id)
async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (await session.execute(q)).first() is not None
async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]:
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
.limit(1)
)
).scalars().first()
async def list_cache_states_by_asset_id(
session: AsyncSession, *, asset_id: str
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def pick_best_live_path(states: Union[list[AssetCacheState], Sequence[AssetCacheState]]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
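Illustrative only: resolving a local file path for a known asset hash with the query helpers above. The import path is assumed from how manager.py (later in this diff) imports these names; the hash string is a placeholder.
from app.db import create_session
from app.assets.database.services import (
    get_asset_by_hash,
    list_cache_states_by_asset_id,
    pick_best_live_path,
)

async def resolve_local_path(asset_hash: str) -> str:
    async with await create_session() as session:
        asset = await get_asset_by_hash(session, asset_hash=asset_hash)
        if asset is None:
            return ""
        states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
        return pick_best_live_path(states)  # "" when no cached copy exists on disk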


@ -0,0 +1,6 @@
from datetime import datetime, timezone
def utcnow() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)

app/assets/manager.py Normal file

@ -0,0 +1,556 @@
import contextlib
import logging
import mimetypes
import os
from typing import Optional, Sequence
from comfy_api.internal import async_to_sync
from ..db import create_session
from ._helpers import (
ensure_within_base,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
from .api import schemas_in, schemas_out
from .database.models import Asset
from .database.services import (
add_tags_to_asset_info,
asset_exists_by_hash,
asset_info_exists_for_asset_id,
check_fs_asset_exists_quick,
create_asset_info_for_existing_asset,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_by_hash,
get_asset_info_by_id,
get_asset_tags,
ingest_fs_asset,
list_asset_infos_page,
list_cache_states_by_asset_id,
list_tags_with_usage,
pick_best_live_path,
remove_tags_from_asset_info,
set_asset_info_preview,
touch_asset_info_by_id,
touch_asset_infos_by_fs_path,
update_asset_info_full,
)
from .storage import hashing
async def asset_exists(*, asset_hash: str) -> bool:
async with await create_session() as session:
return await asset_exists_by_hash(session, asset_hash=asset_hash)
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
if tags is None:
tags = []
try:
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
add_local_asset,
tags=list(dict.fromkeys([*path_tags, *tags])),
file_name=asset_name,
file_path=file_path,
)
except ValueError as e:
logging.warning("Skipping non-asset path %s: %s", file_path, e)
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
abs_path = os.path.abspath(file_path)
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
if not size_bytes:
return
async with await create_session() as session:
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
await touch_asset_infos_by_fs_path(session, file_path=abs_path)
await session.commit()
return
asset_hash = hashing.blake3_hash_sync(abs_path)
async with await create_session() as session:
await ingest_fs_asset(
session,
asset_hash="blake3:" + asset_hash,
abs_path=abs_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=None,
info_name=file_name,
tag_origin="automatic",
tags=tags,
)
await session.commit()
async def list_assets(
*,
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
async with await create_session() as session:
infos, tag_map, total = await list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
preview_url=f"/api/assets/{info.id}/content",
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
async with await create_session() as session:
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
async def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
async with await create_session() as session:
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
await session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
async def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: Optional[str] = None,
owner_id: str = "",
expected_asset_hash: Optional[str] = None,
) -> schemas_out.AssetCreated:
try:
digest = await hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
async with await create_session() as session:
existing = await get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = await create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
async with await create_session() as session:
result = await ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = await fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
async def update_asset(
*,
asset_info_id: str,
name: Optional[str] = None,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = await update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
await session.commit()
return schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
async def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: Optional[str],
owner_id: str = "",
) -> schemas_out.AssetDetail:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
await set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
await session.commit()
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
await session.commit()
return False
if not delete_content_if_orphan or not asset_id:
await session.commit()
return True
still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
await session.commit()
return True
states = await list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = await session.get(Asset, asset_id)
if asset_row is not None:
await session.delete(asset_row)
await session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
async def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> Optional[schemas_out.AssetCreated]:
canonical = hash_str.strip().lower()
async with await create_session() as session:
asset = await get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = await create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
async def list_tags(
*,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
async with await create_session() as session:
rows, total = await list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
async def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
await session.commit()
return schemas_out.TagsAdd(**data)
async def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
await session.commit()
return schemas_out.TagsRemove(**data)
def _safe_sort_field(requested: Optional[str]) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: Optional[str], fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
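A hedged sketch of driving the manager API above from async code; it assumes init_db_engine() (see app/db.py later in this diff) has already run, and all file paths, names, and tags are placeholders.
import asyncio
from app.assets import manager  # module path as added in this PR

async def demo() -> None:
    # Register a local file (it is hashed with BLAKE3 unless already known to the DB).
    await manager.add_local_asset(
        tags=["models", "checkpoints"],
        file_name="example.safetensors",
        file_path="/path/to/models/checkpoints/example.safetensors",
    )
    # Page through assets tagged "checkpoints", newest first.
    page = await manager.list_assets(include_tags=["checkpoints"], limit=10, sort="created_at", order="desc")
    for item in page.assets:
        print(item.name, item.asset_hash, item.tags)

# asyncio.run(demo())  # requires the DB engine to be initialized first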

app/assets/scanner.py Normal file

@ -0,0 +1,501 @@
import asyncio
import contextlib
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Literal, Optional
import sqlalchemy as sa
import folder_paths
from ..db import create_session
from ._helpers import (
collect_models_files,
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
list_tree,
new_scan_id,
prefixes_for_root,
ts_to_iso,
)
from .api import schemas_in, schemas_out
from .database.helpers import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
escape_like_prefix,
fast_asset_file_check,
remove_missing_tag_for_asset_id,
seed_from_paths_batch,
)
from .database.models import Asset, AssetCacheState, AssetInfo
from .database.services import (
compute_hash_and_dedup_for_cache_state,
list_cache_states_by_asset_id,
list_cache_states_with_asset_under_prefixes,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
)
LOGGER = logging.getLogger(__name__)
SLOW_HASH_CONCURRENCY = 1
@dataclass
class ScanProgress:
scan_id: str
root: schemas_in.RootType
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
scheduled_at: float = field(default_factory=lambda: time.time())
started_at: Optional[float] = None
finished_at: Optional[float] = None
discovered: int = 0
processed: int = 0
file_errors: list[dict] = field(default_factory=list)
@dataclass
class SlowQueueState:
queue: asyncio.Queue
workers: list[asyncio.Task] = field(default_factory=list)
closed: bool = False
RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}
def current_statuses() -> schemas_out.AssetScanStatusResponse:
scans = []
for root in schemas_in.ALLOWED_ROOTS:
prog = PROGRESS_BY_ROOT.get(root)
if not prog:
continue
scans.append(_scan_progress_to_scan_status_model(prog))
return schemas_out.AssetScanStatusResponse(scans=scans)
async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
results: list[ScanProgress] = []
for root in roots:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
results.append(PROGRESS_BY_ROOT[root])
continue
prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
state = SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
RUNNING_TASKS[root] = asyncio.create_task(
_run_hash_verify_pipeline(root, prog, state),
name=f"asset-scan:{root}",
)
results.append(prog)
return _status_response_for(results)
async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
t_total = time.perf_counter()
created = 0
skipped_existing = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as ex:
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
ap = os.path.abspath(p)
if ap in existing_paths:
skipped_existing += 1
continue
try:
st = os.stat(ap, follow_symlinks=True)
except OSError:
continue
if not st.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(ap)
specs.append(
{
"abs_path": ap,
"size_bytes": st.st_size,
"mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(ap),
}
)
for t in tags:
tag_pool.add(t)
if not specs:
return
async with await create_session() as sess:
if tag_pool:
await ensure_tags_exist(sess, tag_pool, tag_type="user")
result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
await sess.commit()
finally:
LOGGER.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
roots,
time.perf_counter() - t_total,
created,
skipped_existing,
len(paths),
)
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])
def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
return schemas_out.AssetScanStatus(
scan_id=progress.scan_id,
root=progress.root,
status=progress.status,
scheduled_at=ts_to_iso(progress.scheduled_at),
started_at=ts_to_iso(progress.started_at),
finished_at=ts_to_iso(progress.finished_at),
discovered=progress.discovered,
processed=progress.processed,
file_errors=[
schemas_out.AssetScanError(
path=e.get("path", ""),
message=e.get("message", ""),
at=e.get("at"),
)
for e in (progress.file_errors or [])
],
)
async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
prog.status = "running"
prog.started_at = time.time()
try:
prefixes = prefixes_for_root(root)
await _fast_db_consistency_pass(root)
# collect candidates from DB
async with await create_session() as sess:
verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)
# dedupe: prioritize verification first
seen = set()
ordered: list[int] = []
for lst in (verify_ids, unhashed_ids):
for sid in lst:
if sid not in seen:
seen.add(sid)
ordered.append(sid)
prog.discovered = len(ordered)
# queue up work
for sid in ordered:
await state.queue.put(sid)
state.closed = True
_start_state_workers(root, prog, state)
await _await_state_workers_then_finish(root, prog, state)
except asyncio.CancelledError:
prog.status = "cancelled"
raise
except Exception as exc:
_append_error(prog, path="", message=str(exc))
prog.status = "failed"
prog.finished_at = time.time()
LOGGER.exception("Asset scan failed for %s", root)
finally:
RUNNING_TASKS.pop(root, None)
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
"""
Detect missing files quickly and toggle 'missing' tag per asset_id.
Rules:
- Only hashed assets (assets.hash != NULL) participate in missing tagging.
- We consider ALL cache states of the asset (across roots) before tagging.
"""
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
elif root == "input":
bases = [folder_paths.get_input_directory()]
else:
bases = [folder_paths.get_output_directory()]
try:
async with await create_session() as sess:
# state + hash + size for the current root
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
# Track fast_ok within the scanned root and whether the asset is hashed
by_asset: dict[str, dict[str, bool]] = {}
for state, a_hash, size_db in rows:
aid = state.asset_id
acc = by_asset.get(aid)
if acc is None:
acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
by_asset[aid] = acc
try:
if acc["hashed"]:
st = os.stat(state.file_path, follow_symlinks=True)
if fast_asset_file_check(mtime_db=state.mtime_ns, size_db=acc["size_db"], stat_result=st):
acc["any_fast_ok_here"] = True
except FileNotFoundError:
pass
except OSError as e:
_append_error(prog, path=state.file_path, message=str(e))
# Decide per asset, considering ALL its states (not just this root)
for aid, acc in by_asset.items():
try:
if not acc["hashed"]:
# Never tag seed assets as missing
continue
any_fast_ok_global = acc["any_fast_ok_here"]
if not any_fast_ok_global:
# Check other states outside this root
others = await list_cache_states_by_asset_id(sess, asset_id=aid)
for st in others:
try:
any_fast_ok_global = fast_asset_file_check(
mtime_db=st.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(st.file_path, follow_symlinks=True),
)
if any_fast_ok_global:
break
except OSError:
continue
if any_fast_ok_global:
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
else:
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
except Exception as ex:
_append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
await sess.commit()
except Exception as e:
_append_error(prog, path="", message=f"reconcile failed: {e}")
def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
if state.workers:
return
async def _worker(_wid: int):
while True:
sid = await state.queue.get()
try:
if sid is None:
return
try:
async with await create_session() as sess:
# Optional: fetch path for better error messages
st = await sess.get(AssetCacheState, sid)
try:
await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
await sess.commit()
except Exception as e:
path = st.file_path if st else f"state:{sid}"
_append_error(prog, path=path, message=str(e))
raise
except Exception:
pass
finally:
prog.processed += 1
finally:
state.queue.task_done()
state.workers = [
asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
for i in range(SLOW_HASH_CONCURRENCY)
]
async def _close_when_ready():
while not state.closed:
await asyncio.sleep(0.05)
for _ in range(SLOW_HASH_CONCURRENCY):
await state.queue.put(None)
asyncio.create_task(_close_when_ready())
async def _await_state_workers_then_finish(
root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState
) -> None:
if state.workers:
await asyncio.gather(*state.workers, return_exceptions=True)
await _reconcile_missing_tags_for_root(root, prog)
prog.finished_at = time.time()
prog.status = "completed"
def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
prog.file_errors.append({
"path": path,
"message": message,
"at": ts_to_iso(time.time()),
})
async def _fast_db_consistency_pass(
root: schemas_in.RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> Optional[set[str]]:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
async with await create_session() as sess:
rows = (
await sess.execute(
sa.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove the seed Asset entirely when none of its cached files exist on disk
await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = await sess.get(Asset, aid)
if asset:
await sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
await sess.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
await sess.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
await sess.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
await sess.commit()
return survivors if collect_existing_paths else None
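For illustration, the scan pipeline above could be driven directly like this (the HTTP endpoints added elsewhere in this PR do the same); the module path is assumed.
import asyncio
from app.assets import scanner  # assumed module path

async def rescan_models() -> None:
    # Fast pass: reconcile the DB with files on disk and seed new rows without hashing.
    await scanner.sync_seed_assets(["models"])
    # Slow pass: schedule hashing/verification workers, then poll until they finish.
    await scanner.schedule_scans(["models"])
    while True:
        status = scanner.current_statuses()
        if status.scans and all(s.status in ("completed", "failed", "cancelled") for s in status.scans):
            break
        await asyncio.sleep(1.0)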



@ -0,0 +1,72 @@
import asyncio
import os
from typing import IO, Union
from blake3 import blake3
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
"""Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
orig_pos = None
if hasattr(file_obj, "tell"):
orig_pos = file_obj.tell()
try:
if hasattr(file_obj, "seek"):
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
if hasattr(file_obj, "seek") and orig_pos is not None:
file_obj.seek(orig_pos)
def blake3_hash_sync(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
if hasattr(fp, "read"):
return _hash_file_obj_sync(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
async def blake3_hash(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# For file objects, hash in a worker thread; for plain paths, open inside the worker to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
return await asyncio.to_thread(_worker)
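Quick usage sketch for the two hashing entry points above; the import path follows manager.py's "from .storage import hashing", and the file path is a placeholder.
import asyncio
from app.assets.storage import hashing

def sync_example() -> None:
    digest = hashing.blake3_hash_sync("/path/to/model.safetensors")
    print("blake3:" + digest)  # the rest of the PR stores hashes with a "blake3:" prefix

async def async_example() -> None:
    # Same streaming hash, run in a worker thread so the event loop stays responsive.
    digest = await hashing.blake3_hash("/path/to/model.safetensors")
    print("blake3:" + digest)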


@ -1,112 +0,0 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
Session = None
try:
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
)
def dependencies_available():
"""
Temporary function to check if the dependencies are available
"""
return _DB_AVAILABLE
def can_create_session():
"""
Temporary function to check if the database is available to create a session
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
"""
return dependencies_available() and Session is not None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
config.set_main_option("sqlalchemy.url", args.database_url)
return config
def get_db_path():
url = args.database_url
if url.startswith("sqlite:///"):
return url.split("///")[1]
else:
raise ValueError(f"Unsupported database URL '{url}'.")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
config = get_alembic_config()
# Check if we need to upgrade
engine = create_engine(db_url)
conn = engine.connect()
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(config)
target_rev = script.get_current_head()
if target_rev is None:
logging.warning("No target revision found.")
elif current_rev != target_rev:
# Backup the database pre upgrade
backup_path = db_path + ".bkp"
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.exception("Error upgrading database: ")
raise e
global Session
Session = sessionmaker(bind=engine)
def create_session():
return Session()


@ -1,14 +0,0 @@
from sqlalchemy.orm import declarative_base
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
# TODO: Define models here

app/db.py Normal file

@ -0,0 +1,255 @@
import logging
import os
import shutil
from contextlib import asynccontextmanager
from typing import Optional
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from comfy.cli_args import args
LOGGER = logging.getLogger(__name__)
ENGINE: Optional[AsyncEngine] = None
SESSION: Optional[async_sessionmaker] = None
def _root_paths():
"""Resolve alembic.ini and migrations script folder."""
root_path = os.path.abspath(os.path.dirname(__file__))
config_path = os.path.abspath(os.path.join(root_path, "../alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
return config_path, scripts_path
def _absolutize_sqlite_url(db_url: str) -> str:
"""Make SQLite database path absolute. No-op for non-SQLite URLs."""
try:
u = make_url(db_url)
except Exception:
return db_url
if not u.drivername.startswith("sqlite"):
return db_url
db_path: str = u.database or ""
if isinstance(db_path, str) and db_path.startswith("file:"):
return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared"
if not os.path.isabs(db_path):
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
u = u.set(database=db_path)
return str(u)
def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
"""
If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory),
rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present.
Returns: (normalized_url, is_memory)
"""
try:
u = make_url(db_url)
except Exception:
return db_url, False
if not u.drivername.startswith("sqlite"):
return db_url, False
db = u.database or ""
if db == ":memory:":
u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true")
return str(u), True
if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db:
if "uri=true" not in db:
u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true"))
return str(u), True
return str(u), False
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
"""Return the on-disk path for a SQLite URL, else None."""
try:
u = make_url(sync_url)
except Exception:
return None
if not u.drivername.startswith("sqlite"):
return None
db_path = u.database
if isinstance(db_path, str) and db_path.startswith("file:"):
return None  # not an on-disk file when the database is given as a "file:..." URI
return db_path
def _get_alembic_config(sync_url: str) -> Config:
"""Prepare Alembic Config with script location and DB URL."""
config_path, scripts_path = _root_paths()
cfg = Config(config_path)
cfg.set_main_option("script_location", scripts_path)
cfg.set_main_option("sqlalchemy.url", sync_url)
return cfg
async def init_db_engine() -> None:
"""Initialize async engine + sessionmaker and run migrations to head.
This must be called once on application startup before any DB usage.
"""
global ENGINE, SESSION
if ENGINE is not None:
return
raw_url = args.database_url
if not raw_url:
raise RuntimeError("Database URL is not configured.")
db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
db_url = _absolutize_sqlite_url(db_url)
# Prepare async engine
connect_args = {}
if db_url.startswith("sqlite"):
connect_args = {
"check_same_thread": False,
"timeout": 12,
}
if is_mem:
connect_args["uri"] = True
ENGINE = create_async_engine(
db_url,
connect_args=connect_args,
pool_pre_ping=True,
future=True,
)
# Enforce SQLite pragmas on the async engine
if db_url.startswith("sqlite"):
async with ENGINE.begin() as conn:
if not is_mem:
# WAL for concurrency and durability, Foreign Keys for referential integrity
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
if str(current_mode).lower() != "wal":
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
if str(new_mode).lower() != "wal":
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
LOGGER.info("SQLite journal mode set to WAL.")
await conn.execute(text("PRAGMA foreign_keys = ON;"))
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
await _run_migrations(database_url=db_url, connect_args=connect_args)
SESSION = async_sessionmaker(
bind=ENGINE,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async def _run_migrations(database_url: str, connect_args: dict) -> None:
if "postgresql+psycopg" not in database_url:
# SQLite: convert the async SQLAlchemy URL to a sync URL for Alembic.
u = make_url(database_url)
driver = u.drivername
if not driver.startswith("sqlite+aiosqlite"):
raise ValueError(f"Unsupported DB driver: {driver}")
database_url, is_mem = _normalize_sqlite_memory_url(str(u.set(drivername="sqlite")))
database_url = _absolutize_sqlite_url(database_url)
cfg = _get_alembic_config(database_url)
engine = create_engine(database_url, future=True, connect_args=connect_args)
with engine.connect() as conn:
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(cfg)
target_rev = script.get_current_head()
if target_rev is None:
LOGGER.warning("Alembic: no target revision found.")
return
if current_rev == target_rev:
LOGGER.debug("Alembic: database already at head %s", target_rev)
return
LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
# Optional backup for SQLite file DBs
backup_path = None
sqlite_path = _get_sqlite_file_path(database_url)
if sqlite_path and os.path.exists(sqlite_path):
backup_path = sqlite_path + ".bkp"
try:
shutil.copy(sqlite_path, backup_path)
except Exception as exc:
LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
try:
command.upgrade(cfg, target_rev)
except Exception:
if backup_path and os.path.exists(backup_path):
LOGGER.exception("Error upgrading database, attempting restore from backup.")
try:
shutil.copy(backup_path, sqlite_path) # restore
os.remove(backup_path)
except Exception as re:
LOGGER.error("Failed to restore SQLite backup: %s", re)
else:
LOGGER.exception("Error upgrading database, backup is not available.")
raise
def get_engine():
"""Return the global async engine (initialized after init_db_engine())."""
if ENGINE is None:
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
return ENGINE
def get_session_maker():
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
if SESSION is None:
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
return SESSION
@asynccontextmanager
async def session_scope():
"""Async context manager for a unit of work:
async with session_scope() as sess:
... use sess ...
"""
maker = get_session_maker()
async with maker() as sess:
try:
yield sess
await sess.commit()
except Exception:
await sess.rollback()
raise
async def create_session():
"""Convenience helper to acquire a single AsyncSession instance.
Typical usage:
async with (await create_session()) as sess:
...
"""
maker = get_session_maker()
return maker()
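Minimal startup sketch for the new async DB layer, assuming it is awaited once before any asset code runs (main.py later in this diff does this in setup_database()).
import asyncio
from sqlalchemy import text
from app.db import init_db_engine, session_scope

async def main() -> None:
    await init_db_engine()               # create the engine and run Alembic migrations to head
    async with session_scope() as sess:  # commits on success, rolls back on exception
        print((await sess.execute(text("SELECT 1"))).scalar_one())

# asyncio.run(main())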


@ -196,6 +196,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
"""
A class to manage ComfyUI frontend versions and installations.
This class handles the initialization and management of different frontend versions,
including the default frontend from the pip package and custom frontend versions
from GitHub repositories.
Attributes:
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
"""
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
@ -205,6 +216,15 @@ class FrontendManager:
@classmethod
def default_frontend_path(cls) -> str:
"""
Get the path to the default frontend installation from the pip package.
Returns:
str: The path to the default frontend static files.
Raises:
SystemExit: If the comfyui-frontend-package is not installed.
"""
try:
import comfyui_frontend_package
@ -225,6 +245,15 @@ comfyui-frontend-package is not installed.
@classmethod
def templates_path(cls) -> str:
"""
Get the path to the workflow templates.
Returns:
str: The path to the workflow templates directory.
Raises:
SystemExit: If the comfyui-workflow-templates package is not installed.
"""
try:
import comfyui_workflow_templates
@ -260,11 +289,16 @@ comfyui-workflow-templates is not installed.
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Parse a version string into its components.
The version string should be in the format: 'owner/repo@version'
where version can be either a semantic version (v1.2.3) or 'latest'.
Args:
value (str): The version string to parse.
Returns:
tuple[str, str]: A tuple containing provider name and version.
tuple[str, str, str]: A tuple containing (owner, repo, version).
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
@ -281,18 +315,22 @@ comfyui-workflow-templates is not installed.
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
"""
Initializes the frontend for the specified version.
Initialize a frontend version without error handling.
This method attempts to initialize a specific frontend version, either from
the default pip package or from a custom GitHub repository. It will download
and extract the frontend files if necessary.
Args:
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
version_string (str): The version string specifying which frontend to use.
provider (FrontEndProvider, optional): The provider to use for custom frontends.
Returns:
str: The path to the initialized frontend.
Raises:
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
Exception: If there is an error during initialization (e.g., network timeout,
invalid URL, or missing assets).
"""
if version_string == DEFAULT_VERSION_STRING:
check_frontend_version()
@ -344,13 +382,17 @@ comfyui-workflow-templates is not installed.
@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initializes the frontend with the specified version string.
Initialize a frontend version with error handling.
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
with error handling, falling back to the default frontend if initialization fails.
Args:
version_string (str): The version string to initialize the frontend with.
version_string (str): The version string specifying which frontend to use.
Returns:
str: The path of the initialized frontend.
str: The path to the initialized frontend. If initialization fails,
returns the path to the default frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)


@ -212,7 +212,8 @@ parser.add_argument(
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
if comfy.options.args_parsing:
args = parser.parse_args()


@ -50,10 +50,16 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def is_html_file(file_path):
with open(file_path, "rb") as f:
content = f.read(100)
return b"<!DOCTYPE html>" in content or b"<html" in content
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
@ -66,6 +72,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if return_metadata:
metadata = f.metadata()
except Exception as e:
if is_html_file(ckpt):
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document, not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
@ -93,6 +101,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd = pl_sd
else:
sd = pl_sd
# populate_db_with_asset(ckpt)  # disabled for now: would hash the model file and register it in the assets DB
return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):


@ -392,6 +392,20 @@ class MultiCombo(ComfyTypeI):
})
return to_return
@comfytype(io_type="ASSET")
class Asset(ComfyTypeI):
class Input(WidgetInput):
def __init__(self, id: str, query_tags: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: str=None, socketless: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
self.query_tags = query_tags
def as_dict(self):
to_return = super().as_dict() | prune_dict({
"query_tags": self.query_tags
})
return to_return
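A hypothetical node input declaration using the new ASSET type above; only the Input constructor shown in this diff is assumed, and the surrounding node schema is omitted.
checkpoint_input = Asset.Input(
    "checkpoint",
    query_tags=["models", "checkpoints"],  # lets the frontend filter the asset picker by tags
    tooltip="Pick a checkpoint registered in the assets database.",
)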
@comfytype(io_type="IMAGE")
class Image(ComfyTypeIO):
Type = torch.Tensor


@ -279,10 +279,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None:
"""
Get the full path of a file in a folder, has to be a file
"""
def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
@ -295,6 +292,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return full_path
elif os.path.islink(full_path):
logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
elif allow_missing:
return full_path
return None
@ -309,6 +308,27 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
return full_path
def get_relative_path(full_path: str) -> tuple[str, str] | None:
"""Convert a full path back to a type-relative path.
Args:
full_path: The full path to the file
Returns:
tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise
"""
global folder_names_and_paths
full_path = os.path.normpath(full_path)
for model_type, (paths, _) in folder_names_and_paths.items():
for base_path in paths:
base_path = os.path.normpath(base_path)
if full_path.startswith(base_path):
relative_path = os.path.relpath(full_path, base_path)
return model_type, relative_path
return None
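Round-trip sketch for the path helpers above; the folder name and filename are placeholders and depend on the local install.
import folder_paths

full = folder_paths.get_full_path_or_raise("checkpoints", "sd15.safetensors")
pair = folder_paths.get_relative_path(full)
if pair is not None:
    model_type, rel = pair
    print(model_type, rel)  # e.g. ("checkpoints", "sd15.safetensors")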
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths

main.py

@ -164,7 +164,6 @@ def cuda_malloc_warning():
if cuda_malloc_warning:
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_type = execution.CacheType.CLASSIC
@ -279,14 +278,13 @@ def cleanup_temp():
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
async def setup_database():
from app.assets import sync_seed_assets
from app.db import init_db_engine
def setup_database():
try:
from app.database.db import init_db, dependencies_available
if dependencies_available():
init_db()
except Exception as e:
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
await init_db_engine()
if not args.disable_assets_autoscan:
await sync_seed_assets(["models"])
def start_comfyui(asyncio_loop=None):
@ -312,6 +310,8 @@ def start_comfyui(asyncio_loop=None):
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
asyncio_loop.run_until_complete(setup_database())
hook_breaker_ac10a0.save_functions()
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
@ -320,7 +320,6 @@ def start_comfyui(asyncio_loop=None):
hook_breaker_ac10a0.restore_functions()
cuda_malloc_warning()
setup_database()
prompt_server.add_routes()
hijack_progress(prompt_server)

View File

@ -20,7 +20,9 @@ tqdm
psutil
alembic
SQLAlchemy
aiosqlite
av>=14.2.0
blake3
#non essential dependencies:
kornia>=0.7.1
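The new aiosqlite and blake3 requirements back the async DB engine and content hashing. As a hedged illustration of the "blake3:<hex>" form the tests below rely on, here is one way a file could be hashed; the helper name is hypothetical and the hashing code actually used by ComfyUI is not part of this diff.

import blake3

def compute_asset_hash(path: str, chunk_size: int = 1024 * 1024) -> str:
    """Stream a file through BLAKE3 and return it in the "blake3:<hex>" form."""
    hasher = blake3.blake3()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return "blake3:" + hasher.hexdigest()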

View File

@ -37,6 +37,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from app.assets import sync_seed_assets, register_assets_system
from protocol import BinaryEventTypes
# Import cache control middleware
@ -178,6 +179,7 @@ class PromptServer():
else args.front_end_root
)
logging.info(f"[Prompt Server] web root: {self.web_root}")
register_assets_system(self.app, self.user_manager)
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None
@ -622,6 +624,7 @@ class PromptServer():
@routes.get("/object_info")
async def get_object_info(request):
await sync_seed_assets(["models"])
with folder_paths.cache_helper:
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
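With the assets system registered on the PromptServer app, the new endpoints can be queried like any other ComfyUI route. A minimal sketch, assuming a locally running server on the default port; the endpoint and query parameters mirror the tests added in this PR.

import aiohttp

async def list_checkpoint_assets(base_url: str = "http://127.0.0.1:8188") -> list[dict]:
    """Return the first page of assets tagged models+checkpoints, sorted by name."""
    async with aiohttp.ClientSession() as session:
        async with session.get(
            base_url + "/api/assets",
            params={"include_tags": "models,checkpoints", "sort": "name", "limit": "50"},
        ) as resp:
            body = await resp.json()
            return body.get("assets", [])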

307
tests-assets/conftest.py Normal file
View File

@ -0,0 +1,307 @@
import asyncio
import contextlib
import json
import os
import socket
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import AsyncIterator, Callable, Optional
import aiohttp
import pytest
import pytest_asyncio
def pytest_addoption(parser: pytest.Parser) -> None:
"""
Allow overriding the database URL used by the spawned ComfyUI process.
Priority:
1) --db-url command line option
2) ASSETS_TEST_DB_URL environment variable (used by CI)
3) default: sqlite in-memory
"""
parser.addoption(
"--db-url",
action="store",
default=os.environ.get("ASSETS_TEST_DB_URL", "sqlite+aiosqlite:///:memory:"),
help="Async SQLAlchemy DB URL (e.g. sqlite+aiosqlite:///:memory: or postgresql+psycopg://user:pass@host/db)",
)
def _free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _make_base_dirs(root: Path) -> None:
for sub in ("models", "custom_nodes", "input", "output", "temp", "user"):
(root / sub).mkdir(parents=True, exist_ok=True)
async def _wait_http_ready(base: str, session: aiohttp.ClientSession, timeout: float = 90.0) -> None:
start = time.time()
last_err = None
while time.time() - start < timeout:
try:
async with session.get(base + "/api/assets") as r:
if r.status in (200, 400):
return
except Exception as e:
last_err = e
await asyncio.sleep(0.25)
raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}")
@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
def comfy_tmp_base_dir() -> Path:
env_base = os.environ.get("ASSETS_TEST_BASE_DIR")
created_by_fixture = False
if env_base:
tmp = Path(env_base)
tmp.mkdir(parents=True, exist_ok=True)
else:
tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
created_by_fixture = True
_make_base_dirs(tmp)
yield tmp
if created_by_fixture:
with contextlib.suppress(Exception):
for p in sorted(tmp.rglob("*"), reverse=True):
if p.is_file() or p.is_symlink():
p.unlink(missing_ok=True)
for p in sorted(tmp.glob("**/*"), reverse=True):
with contextlib.suppress(Exception):
p.rmdir()
tmp.rmdir()
@pytest.fixture(scope="session")
def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest):
"""
Boot ComfyUI subprocess with:
- sandbox base dir
- sqlite memory DB (default)
- autoscan disabled
Returns (base_url, process, port)
"""
port = _free_port()
db_url = request.config.getoption("--db-url")
logs_dir = comfy_tmp_base_dir / "logs"
logs_dir.mkdir(exist_ok=True)
out_log = open(logs_dir / "stdout.log", "w", buffering=1)
err_log = open(logs_dir / "stderr.log", "w", buffering=1)
comfy_root = Path(__file__).resolve().parent.parent
if not (comfy_root / "main.py").is_file():
raise FileNotFoundError(f"main.py not found under {comfy_root}")
proc = subprocess.Popen(
args=[
sys.executable,
"main.py",
f"--base-directory={str(comfy_tmp_base_dir)}",
f"--database-url={db_url}",
"--disable-assets-autoscan",
"--listen",
"127.0.0.1",
"--port",
str(port),
"--cpu",
],
stdout=out_log,
stderr=err_log,
cwd=str(comfy_root),
env={**os.environ},
)
for _ in range(50):
if proc.poll() is not None:
out_log.flush()
err_log.flush()
raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}")
time.sleep(0.1)
base_url = f"http://127.0.0.1:{port}"
try:
async def _probe():
async with aiohttp.ClientSession() as s:
await _wait_http_ready(base_url, s, timeout=90.0)
asyncio.run(_probe())
yield base_url, proc, port
except Exception as e:
with contextlib.suppress(Exception):
proc.terminate()
proc.wait(timeout=10)
with contextlib.suppress(Exception):
out_log.flush()
err_log.flush()
raise RuntimeError(f"ComfyUI did not become ready: {e}")
if proc and proc.poll() is None:
with contextlib.suppress(Exception):
proc.terminate()
proc.wait(timeout=15)
out_log.close()
err_log.close()
@pytest_asyncio.fixture
async def http() -> AsyncIterator[aiohttp.ClientSession]:
timeout = aiohttp.ClientTimeout(total=120)
async with aiohttp.ClientSession(timeout=timeout) as s:
yield s
@pytest.fixture
def api_base(comfy_url_and_proc) -> str:
base_url, _proc, _port = comfy_url_and_proc
return base_url
async def _post_multipart_asset(
session: aiohttp.ClientSession,
base: str,
*,
name: str,
tags: list[str],
meta: dict,
data: bytes,
extra_fields: Optional[dict] = None,
) -> tuple[int, dict]:
form = aiohttp.FormData()
form.add_field("file", data, filename=name, content_type="application/octet-stream")
form.add_field("tags", json.dumps(tags))
form.add_field("name", name)
form.add_field("user_metadata", json.dumps(meta))
if extra_fields:
for k, v in extra_fields.items():
form.add_field(k, v)
async with session.post(base + "/api/assets", data=form) as r:
body = await r.json()
return r.status, body
@pytest.fixture
def make_asset_bytes() -> Callable[[str, int], bytes]:
def _make(name: str, size: int = 8192) -> bytes:
seed = sum(ord(c) for c in name) % 251
return bytes((i * 31 + seed) % 256 for i in range(size))
return _make
@pytest_asyncio.fixture
async def asset_factory(http: aiohttp.ClientSession, api_base: str):
"""
Returns create(name, tags, meta, data) -> response dict
Tracks created ids and deletes them after the test.
"""
created: list[str] = []
async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
status, body = await _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
assert status in (200, 201), body
created.append(body["id"])
return body
yield create
# cleanup by id
for aid in created:
with contextlib.suppress(Exception):
async with http.delete(f"{api_base}/api/assets/{aid}") as r:
await r.read()
@pytest_asyncio.fixture
async def seeded_asset(request: pytest.FixtureRequest, http: aiohttp.ClientSession, api_base: str) -> dict:
"""
Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/<name>.
Returns response dict with id, asset_hash, tags, etc.
"""
name = "unit_1_example.safetensors"
p = getattr(request, "param", {}) or {}
tags: Optional[list[str]] = p.get("tags")
if tags is None:
tags = ["models", "checkpoints", "unit-tests", "alpha"]
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
form = aiohttp.FormData()
form.add_field("file", b"A" * 4096, filename=name, content_type="application/octet-stream")
form.add_field("tags", json.dumps(tags))
form.add_field("name", name)
form.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 201, body
return body
@pytest_asyncio.fixture(autouse=True)
async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str):
"""Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
yield
while True:
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
) as r:
body = await r.json()
if r.status != 200:
break
ids = [a["id"] for a in body.get("assets", [])]
if not ids:
break
for aid in ids:
with contextlib.suppress(Exception):
async with http.delete(f"{api_base}/api/assets/{aid}") as dr:
await dr.read()
async def trigger_sync_seed_assets(session: aiohttp.ClientSession, base_url: str) -> None:
"""Force a fast sync/seed pass by calling the ComfyUI '/object_info' endpoint."""
async with session.post(base_url + "/api/assets/scan/seed", json={"roots": ["models", "input", "output"]}) as r:
await r.read()
await asyncio.sleep(0.1) # tiny yield to the event loop to let any final DB commits flush
@pytest_asyncio.fixture
async def run_scan_and_wait(http: aiohttp.ClientSession, api_base: str):
"""Schedule an asset scan for a given root and wait until it finishes."""
async def _run(root: str, timeout: float = 120.0):
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r:
# we ignore body; scheduling returns 202 with a status payload
await r.read()
start = time.time()
while True:
async with http.get(api_base + "/api/assets/scan", params={"root": root}) as st:
body = await st.json()
scans = (body or {}).get("scans", [])
status = None
if scans:
status = scans[-1].get("status")
if status in {"completed", "failed", "cancelled"}:
if status != "completed":
raise RuntimeError(f"Scan for root={root} finished with status={status}")
return
if time.time() - start > timeout:
raise TimeoutError(f"Timed out waiting for scan of root={root}")
await asyncio.sleep(0.1)
return _run
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension
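For readers skimming the fixtures above, a hypothetical test module would combine them roughly like this (the asset name, tags, and metadata are illustrative):

import pytest

@pytest.mark.asyncio
async def test_upload_and_fetch(http, api_base, asset_factory, make_asset_bytes):
    # asset_factory uploads via POST /api/assets and deletes the asset after the test.
    data = make_asset_bytes("demo.safetensors", 2048)
    created = await asset_factory(
        "demo.safetensors",
        ["models", "checkpoints", "unit-tests", "demo"],
        {"purpose": "demo"},
        data,
    )
    async with http.get(f"{api_base}/api/assets/{created['id']}") as r:
        detail = await r.json()
        assert r.status == 200
        assert detail["asset_hash"] == created["asset_hash"]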

View File

@ -0,0 +1,347 @@
import os
import uuid
from pathlib import Path
import aiohttp
import pytest
from conftest import get_asset_filename, trigger_sync_seed_assets
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_seed_asset_removed_when_file_is_deleted(
root: str,
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
):
"""Asset without hash (seed) whose file disappears:
after triggering sync_seed_assets, Asset + AssetInfo disappear.
"""
# Create a file directly under input/unit-tests/<case> so tags include "unit-tests"
case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
case_dir.mkdir(parents=True, exist_ok=True)
name = f"seed_{uuid.uuid4().hex[:8]}.bin"
fp = case_dir / name
fp.write_bytes(b"Z" * 2048)
# Trigger a seed sync so DB sees this path (seed asset => hash is NULL)
await trigger_sync_seed_assets(http, api_base)
# Verify it is visible via API and carries no hash (seed)
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
) as r1:
body1 = await r1.json()
assert r1.status == 200
# there should be exactly one with that name
matches = [a for a in body1.get("assets", []) if a.get("name") == name]
assert matches
assert matches[0].get("asset_hash") is None
asset_info_id = matches[0]["id"]
# Remove the underlying file and sync again
if fp.exists():
fp.unlink()
await trigger_sync_seed_assets(http, api_base)
# It should disappear (AssetInfo and seed Asset gone)
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
) as r2:
body2 = await r2.json()
assert r2.status == 200
matches2 = [a for a in body2.get("assets", []) if a.get("name") == name]
assert not matches2, f"Seed asset {asset_info_id} should be gone after sync"
@pytest.mark.asyncio
async def test_hashed_asset_missing_tag_added_then_removed_after_scan(
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
make_asset_bytes,
run_scan_and_wait,
):
"""Hashed asset with a single cache_state:
1. delete its file -> sync adds 'missing'
2. restore file -> scan removes 'missing'
"""
name = "missing_tag_test.png"
tags = ["input", "unit-tests", "msync2"]
data = make_asset_bytes(name, 4096)
a = await asset_factory(name, tags, {}, data)
# Compute its on-disk path and remove it
dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png")
assert dest.exists(), f"Expected asset file at {dest}"
dest.unlink()
# Fast sync should add 'missing' to the AssetInfo
await trigger_sync_seed_assets(http, api_base)
async with http.get(f"{api_base}/api/assets/{a['id']}") as g1:
d1 = await g1.json()
assert g1.status == 200, d1
assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion"
# Restore the file with the exact same content and re-hash/verify via scan
dest.parent.mkdir(parents=True, exist_ok=True)
dest.write_bytes(data)
await run_scan_and_wait("input")
async with http.get(f"{api_base}/api/assets/{a['id']}") as g2:
d2 = await g2.json()
assert g2.status == 200, d2
assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify"
@pytest.mark.asyncio
async def test_hashed_asset_two_asset_infos_both_get_missing(
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
):
"""Hashed asset with a single cache_state, but two AssetInfo rows:
deleting the single file then syncing should add 'missing' to both infos.
"""
# Upload one hashed asset
name = "two_infos_one_path.png"
base_tags = ["input", "unit-tests", "multiinfo"]
created = await asset_factory(name, base_tags, {}, b"A" * 2048)
# Create second AssetInfo for the same Asset via from-hash
payload = {
"hash": created["asset_hash"],
"name": "two_infos_one_path_copy.png",
"tags": base_tags, # keep it in our unit-tests scope for cleanup
"user_metadata": {"k": "v"},
}
async with http.post(api_base + "/api/assets/from-hash", json=payload) as r2:
b2 = await r2.json()
assert r2.status == 201, b2
second_id = b2["id"]
# Remove the single underlying file
p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png")
assert p.exists()
p.unlink()
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r0:
tags0 = await r0.json()
assert r0.status == 200, tags0
byname0 = {t["name"]: t for t in tags0.get("tags", [])}
old_missing = int(byname0.get("missing", {}).get("count", 0))
# Sync -> both AssetInfos for this asset must receive 'missing'
await trigger_sync_seed_assets(http, api_base)
async with http.get(f"{api_base}/api/assets/{created['id']}") as ga:
da = await ga.json()
assert ga.status == 200, da
assert "missing" in set(da.get("tags", []))
async with http.get(f"{api_base}/api/assets/{second_id}") as gb:
db = await gb.json()
assert gb.status == 200, db
assert "missing" in set(db.get("tags", []))
# Tag usage for 'missing' increased by exactly 2 (two AssetInfos)
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1:
tags1 = await r1.json()
assert r1.status == 200, tags1
byname1 = {t["name"]: t for t in tags1.get("tags", [])}
new_missing = int(byname1.get("missing", {}).get("count", 0))
assert new_missing == old_missing + 2
@pytest.mark.asyncio
async def test_hashed_asset_two_cache_states_partial_delete_then_full_delete(
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
make_asset_bytes,
run_scan_and_wait,
):
"""Hashed asset with two cache_state rows:
1. delete one file -> sync should NOT add 'missing'
2. delete second file -> sync should add 'missing'
"""
name = "two_cache_states_partial_delete.png"
tags = ["input", "unit-tests", "dual"]
data = make_asset_bytes(name, 3072)
created = await asset_factory(name, tags, {}, data)
path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png")
assert path1.exists()
# Create a second on-disk copy under the same root but different subfolder
path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name
path2.parent.mkdir(parents=True, exist_ok=True)
path2.write_bytes(data)
# Fast seed so the second path appears (as a seed initially)
await trigger_sync_seed_assets(http, api_base)
# The two AssetInfo rows are not deduplicated: the first has owner='default' while the second has an empty owner.
await run_scan_and_wait("input")
# Remove only one file and sync -> asset should still be healthy (no 'missing')
path1.unlink()
await trigger_sync_seed_assets(http, api_base)
async with http.get(f"{api_base}/api/assets/{created['id']}") as g1:
d1 = await g1.json()
assert g1.status == 200, d1
assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains"
# Baseline 'missing' usage count just before last file removal
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r0:
tags0 = await r0.json()
assert r0.status == 200, tags0
old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0))
# Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo
path2.unlink()
await trigger_sync_seed_assets(http, api_base)
async with http.get(f"{api_base}/api/assets/{created['id']}") as g2:
d2 = await g2.json()
assert g2.status == 200, d2
assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain"
# Tag usage for 'missing' increased by exactly 2 (two AssetInfo rows for one Asset)
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1:
tags1 = await r1.json()
assert r1.status == 200, tags1
new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0))
assert new_missing == old_missing + 2
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
root: str,
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
make_asset_bytes,
):
"""
Fast pass alone clears 'missing' when size and mtime match exactly:
1) upload (hashed), record original mtime_ns
2) delete -> fast pass adds 'missing'
3) restore same bytes and set mtime back to the original value
4) run fast pass again -> 'missing' is removed (no slow scan)
"""
scope = f"fastclear-{uuid.uuid4().hex[:6]}"
name = "fastpass_clear.bin"
data = make_asset_bytes(name, 3072)
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
aid = a["id"]
base = comfy_tmp_base_dir / root / "unit-tests" / scope
p = base / get_asset_filename(a["asset_hash"], ".bin")
st0 = p.stat()
orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))
# Delete -> fast pass adds 'missing'
p.unlink()
await trigger_sync_seed_assets(http, api_base)
async with http.get(f"{api_base}/api/assets/{aid}") as g1:
d1 = await g1.json()
assert g1.status == 200, d1
assert "missing" in set(d1.get("tags", []))
# Restore same bytes and revert mtime to the original value
p.parent.mkdir(parents=True, exist_ok=True)
p.write_bytes(data)
# set both atime and mtime in ns to ensure exact match
os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns))
# Fast pass should clear 'missing' without a scan
await trigger_sync_seed_assets(http, api_base)
async with http.get(f"{api_base}/api/assets/{aid}") as g2:
d2 = await g2.json()
assert g2.status == 200, d2
assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match"
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_fastpass_removes_stale_state_row_no_missing(
root: str,
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
make_asset_bytes,
run_scan_and_wait,
):
"""
Hashed asset with two states:
- delete one file
- run fast pass only
Expect:
- asset stays healthy (no 'missing')
- stale AssetCacheState row for the deleted path is removed.
We verify this behaviorally by recreating the deleted path and running fast pass again:
a new *seed* AssetInfo is created, which proves the old state row was not reused.
"""
scope = f"stale-{uuid.uuid4().hex[:6]}"
name = "two_states.bin"
data = make_asset_bytes(name, 2048)
# Upload hashed asset at path1
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
base = comfy_tmp_base_dir / root / "unit-tests" / scope
a1_filename = get_asset_filename(a["asset_hash"], ".bin")
p1 = base / a1_filename
assert p1.exists()
aid = a["id"]
h = a["asset_hash"]
# Create second state path2, seed+scan to dedupe into the same Asset
p2 = base / "copy" / name
p2.parent.mkdir(parents=True, exist_ok=True)
p2.write_bytes(data)
await trigger_sync_seed_assets(http, api_base)
await run_scan_and_wait(root)
# Delete path1 and run fast pass -> no 'missing' and stale state row should be removed
p1.unlink()
await trigger_sync_seed_assets(http, api_base)
async with http.get(f"{api_base}/api/assets/{aid}") as g1:
d1 = await g1.json()
assert g1.status == 200, d1
assert "missing" not in set(d1.get("tags", []))
# Recreate path1 and run fast pass again.
# If the stale state row was removed, a NEW seed AssetInfo will appear for this path.
p1.write_bytes(data)
await trigger_sync_seed_assets(http, api_base)
async with http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}"},
) as rl:
bl = await rl.json()
assert rl.status == 200, bl
items = bl.get("assets", [])
# one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null)
hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)]
assert h in hashes
assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path"
# Asset identity still healthy
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh:
assert rh.status == 200

316
tests-assets/test_crud.py Normal file
View File

@ -0,0 +1,316 @@
import asyncio
import uuid
from pathlib import Path
import aiohttp
import pytest
from conftest import get_asset_filename, trigger_sync_seed_assets
@pytest.mark.asyncio
async def test_create_from_hash_success(
http: aiohttp.ClientSession, api_base: str, seeded_asset: dict
):
h = seeded_asset["asset_hash"]
payload = {
"hash": h,
"name": "from_hash_ok.safetensors",
"tags": ["models", "checkpoints", "unit-tests", "from-hash"],
"user_metadata": {"k": "v"},
}
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r1:
b1 = await r1.json()
assert r1.status == 201, b1
assert b1["asset_hash"] == h
assert b1["created_new"] is False
aid = b1["id"]
# Calling again with the same name should return the same AssetInfo id
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r2:
b2 = await r2.json()
assert r2.status == 201, b2
assert b2["id"] == aid
@pytest.mark.asyncio
async def test_get_and_delete_asset(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]
# GET detail
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
detail = await rg.json()
assert rg.status == 200, detail
assert detail["id"] == aid
assert "user_metadata" in detail
assert "filename" in detail["user_metadata"]
# DELETE
async with http.delete(f"{api_base}/api/assets/{aid}") as rd:
assert rd.status == 204
# GET again -> 404
async with http.get(f"{api_base}/api/assets/{aid}") as rg2:
body = await rg2.json()
assert rg2.status == 404
assert body["error"]["code"] == "ASSET_NOT_FOUND"
@pytest.mark.asyncio
async def test_delete_upon_reference_count(
http: aiohttp.ClientSession, api_base: str, seeded_asset: dict
):
# Create a second reference to the same asset via from-hash
src_hash = seeded_asset["asset_hash"]
payload = {
"hash": src_hash,
"name": "unit_ref_copy.safetensors",
"tags": ["models", "checkpoints", "unit-tests", "del-flow"],
"user_metadata": {"note": "copy"},
}
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r2:
copy = await r2.json()
assert r2.status == 201, copy
assert copy["asset_hash"] == src_hash
assert copy["created_new"] is False
# Delete original reference -> asset identity must remain
aid1 = seeded_asset["id"]
async with http.delete(f"{api_base}/api/assets/{aid1}") as rd1:
assert rd1.status == 204
async with http.head(f"{api_base}/api/assets/hash/{src_hash}") as rh1:
assert rh1.status == 200 # identity still present
# Delete the last reference with default semantics -> identity and cached files removed
aid2 = copy["id"]
async with http.delete(f"{api_base}/api/assets/{aid2}") as rd2:
assert rd2.status == 204
async with http.head(f"{api_base}/api/assets/hash/{src_hash}") as rh2:
assert rh2.status == 404 # orphan content removed
@pytest.mark.asyncio
async def test_update_asset_fields(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]
payload = {
"name": "unit_1_renamed.safetensors",
"tags": ["models", "checkpoints", "unit-tests", "beta"],
"user_metadata": {"purpose": "updated", "epoch": 2},
}
async with http.put(f"{api_base}/api/assets/{aid}", json=payload) as ru:
body = await ru.json()
assert ru.status == 200, body
assert body["name"] == payload["name"]
assert "beta" in body["tags"]
assert body["user_metadata"]["purpose"] == "updated"
# filename should still be present and normalized by server
assert "filename" in body["user_metadata"]
@pytest.mark.asyncio
async def test_head_asset_by_hash(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
h = seeded_asset["asset_hash"]
# Existing
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh1:
assert rh1.status == 200
# Non-existent
async with http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}") as rh2:
assert rh2.status == 404
@pytest.mark.asyncio
async def test_head_asset_bad_hash_returns_400_and_no_body(http: aiohttp.ClientSession, api_base: str):
# Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload.
# aiohttp exposes an empty body for HEAD, so validate status and that there is no payload.
async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh:
assert rh.status == 400
body = await rh.read()
assert body == b""
@pytest.mark.asyncio
async def test_delete_nonexistent_returns_404(http: aiohttp.ClientSession, api_base: str):
bogus = str(uuid.uuid4())
async with http.delete(f"{api_base}/api/assets/{bogus}") as r:
body = await r.json()
assert r.status == 404
assert body["error"]["code"] == "ASSET_NOT_FOUND"
@pytest.mark.asyncio
async def test_create_from_hash_invalids(http: aiohttp.ClientSession, api_base: str):
# Bad hash algorithm
bad = {
"hash": "sha256:" + "0" * 64,
"name": "x.bin",
"tags": ["models", "checkpoints", "unit-tests"],
}
async with http.post(f"{api_base}/api/assets/from-hash", json=bad) as r1:
b1 = await r1.json()
assert r1.status == 400
assert b1["error"]["code"] == "INVALID_BODY"
# Invalid JSON body
async with http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}") as r2:
b2 = await r2.json()
assert r2.status == 400
assert b2["error"]["code"] == "INVALID_JSON"
@pytest.mark.asyncio
async def test_get_update_download_bad_ids(http: aiohttp.ClientSession, api_base: str):
# All endpoints should return 404: asset ids are matched with a UUID regex directly in the route definition, so non-UUID ids never reach the handler.
bad_id = "not-a-uuid"
async with http.get(f"{api_base}/api/assets/{bad_id}") as r1:
assert r1.status == 404
async with http.get(f"{api_base}/api/assets/{bad_id}/content") as r3:
assert r3.status == 404
@pytest.mark.asyncio
async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]
async with http.put(f"{api_base}/api/assets/{aid}", json={}) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "INVALID_BODY"
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_concurrent_delete_same_asset_info_single_204(
root: str,
http: aiohttp.ClientSession,
api_base: str,
asset_factory,
make_asset_bytes,
):
"""
Many concurrent DELETE for the same AssetInfo should result in:
- exactly one 204 No Content (the one that actually deleted)
- all others 404 Not Found (row already gone)
"""
scope = f"conc-del-{uuid.uuid4().hex[:6]}"
name = "to_delete.bin"
data = make_asset_bytes(name, 1536)
created = await asset_factory(name, [root, "unit-tests", scope], {}, data)
aid = created["id"]
# Hit the same endpoint N times in parallel.
n_tests = 4
url = f"{api_base}/api/assets/{aid}?delete_content=false"
tasks = [asyncio.create_task(http.delete(url)) for _ in range(n_tests)]
responses = await asyncio.gather(*tasks)
statuses = [r.status for r in responses]
# Drain bodies to close connections (optional but avoids warnings).
await asyncio.gather(*[r.read() for r in responses])
# Exactly one actual delete, the rest must be 404
assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}"
assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}"
# The resource must be gone.
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
assert rg.status == 404
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_metadata_filename_is_set_for_seed_asset_without_hash(
root: str,
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
):
"""Seed ingest (no hash yet) must compute user_metadata['filename'] immediately."""
scope = f"seedmeta-{uuid.uuid4().hex[:6]}"
name = "seed_filename.bin"
base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b"
base.mkdir(parents=True, exist_ok=True)
fp = base / name
fp.write_bytes(b"Z" * 2048)
await trigger_sync_seed_assets(http, api_base)
async with http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
) as r1:
body = await r1.json()
assert r1.status == 200, body
matches = [a for a in body.get("assets", []) if a.get("name") == name]
assert matches, "Seed asset should be visible after sync"
assert matches[0].get("asset_hash") is None # still a seed
aid = matches[0]["id"]
async with http.get(f"{api_base}/api/assets/{aid}") as r2:
detail = await r2.json()
assert r2.status == 200, detail
filename = (detail.get("user_metadata") or {}).get("filename")
expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/")
assert filename == expected, f"expected filename={expected}, got {filename!r}"
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_metadata_filename_computed_and_updated_on_retarget(
root: str,
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
make_asset_bytes,
run_scan_and_wait,
):
"""
1) Ingest under {root}/unit-tests/<scope>/a/b/<name> -> filename reflects relative path.
2) Retarget by copying to {root}/unit-tests/<scope>/x/<new_name>, remove old file,
run fast pass + scan -> filename updates to new relative path.
"""
scope = f"meta-fn-{uuid.uuid4().hex[:6]}"
name1 = "compute_metadata_filename.png"
name2 = "compute_changed_metadata_filename.png"
data = make_asset_bytes(name1, 2100)
# Upload into nested path a/b
a = await asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data)
aid = a["id"]
root_base = comfy_tmp_base_dir / root
p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png"))
assert p1.exists()
# filename at ingest should be the path relative to root
rel1 = str(p1.relative_to(root_base)).replace("\\", "/")
async with http.get(f"{api_base}/api/assets/{aid}") as g1:
d1 = await g1.json()
assert g1.status == 200, d1
fn1 = d1["user_metadata"].get("filename")
assert fn1 == rel1
# Retarget: copy to x/<name2>, remove old, then sync+scan
p2 = root_base / "unit-tests" / scope / "x" / name2
p2.parent.mkdir(parents=True, exist_ok=True)
p2.write_bytes(data)
if p1.exists():
p1.unlink()
await trigger_sync_seed_assets(http, api_base) # seed the new path
await run_scan_and_wait(root) # verify/hash and reconcile
# filename should now point at x/<name2>
rel2 = str(p2.relative_to(root_base)).replace("\\", "/")
async with http.get(f"{api_base}/api/assets/{aid}") as g2:
d2 = await g2.json()
assert g2.status == 200, d2
fn2 = d2["user_metadata"].get("filename")
assert fn2 == rel2

View File

@ -0,0 +1,168 @@
import asyncio
import uuid
from datetime import datetime
from pathlib import Path
from typing import Optional
import aiohttp
import pytest
from conftest import get_asset_filename, trigger_sync_seed_assets
@pytest.mark.asyncio
async def test_download_attachment_and_inline(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]
# default attachment
async with http.get(f"{api_base}/api/assets/{aid}/content") as r1:
data = await r1.read()
assert r1.status == 200
cd = r1.headers.get("Content-Disposition", "")
assert "attachment" in cd
assert data and len(data) == 4096
# inline requested
async with http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline") as r2:
await r2.read()
assert r2.status == 200
cd2 = r2.headers.get("Content-Disposition", "")
assert "inline" in cd2
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_download_chooses_existing_state_and_updates_access_time(
root: str,
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
make_asset_bytes,
run_scan_and_wait,
):
"""
Hashed asset with two state paths: if the first one disappears,
GET /content still serves from the remaining path and bumps last_access_time.
"""
scope = f"dl-first-{uuid.uuid4().hex[:6]}"
name = "first_existing_state.bin"
data = make_asset_bytes(name, 3072)
# Upload -> path1
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
aid = a["id"]
base = comfy_tmp_base_dir / root / "unit-tests" / scope
path1 = base / get_asset_filename(a["asset_hash"], ".bin")
assert path1.exists()
# Seed path2 by copying, then scan to dedupe into a second state
path2 = base / "alt" / name
path2.parent.mkdir(parents=True, exist_ok=True)
path2.write_bytes(data)
await trigger_sync_seed_assets(http, api_base)
await run_scan_and_wait(root)
# Remove path1 so server must fall back to path2
path1.unlink()
# last_access_time before
async with http.get(f"{api_base}/api/assets/{aid}") as rg0:
d0 = await rg0.json()
assert rg0.status == 200, d0
ts0 = d0.get("last_access_time")
await asyncio.sleep(0.05)
async with http.get(f"{api_base}/api/assets/{aid}/content") as r:
blob = await r.read()
assert r.status == 200
assert blob == data # must serve from the surviving state (same bytes)
async with http.get(f"{api_base}/api/assets/{aid}") as rg1:
d1 = await rg1.json()
assert rg1.status == 200, d1
ts1 = d1.get("last_access_time")
def _parse_iso8601(s: Optional[str]) -> Optional[float]:
if not s:
return None
s = s[:-1] if s.endswith("Z") else s
return datetime.fromisoformat(s).timestamp()
t0 = _parse_iso8601(ts0)
t1 = _parse_iso8601(ts1)
assert t1 is not None
if t0 is not None:
assert t1 > t0
@pytest.mark.asyncio
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)
async def test_download_missing_file_returns_404(
http: aiohttp.ClientSession, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
):
# Remove the underlying file then attempt download.
# The fixture is parametrized without extra tags so the asset's exact on-disk path is known.
try:
aid = seeded_asset["id"]
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
detail = await rg.json()
assert rg.status == 200
asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors")
abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename
assert abs_path.exists()
abs_path.unlink()
async with http.get(f"{api_base}/api/assets/{aid}/content") as r2:
assert r2.status == 404
body = await r2.json()
assert body["error"]["code"] == "FILE_NOT_FOUND"
finally:
# The asset was created without the "unit-tests" tag (see `autoclean_unit_test_assets`), so it must be cleaned up manually here.
async with http.delete(f"{api_base}/api/assets/{aid}") as dr:
await dr.read()
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_download_404_if_all_states_missing(
root: str,
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
make_asset_bytes,
run_scan_and_wait,
):
"""Multi-state asset: after the last remaining on-disk file is removed, download must return 404."""
scope = f"dl-404-{uuid.uuid4().hex[:6]}"
name = "missing_all_states.bin"
data = make_asset_bytes(name, 2048)
# Upload -> path1
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
aid = a["id"]
base = comfy_tmp_base_dir / root / "unit-tests" / scope
p1 = base / get_asset_filename(a["asset_hash"], ".bin")
assert p1.exists()
# Seed a second state and dedupe
p2 = base / "copy" / name
p2.parent.mkdir(parents=True, exist_ok=True)
p2.write_bytes(data)
await trigger_sync_seed_assets(http, api_base)
await run_scan_and_wait(root)
# Remove first file -> download should still work via the second state
p1.unlink()
async with http.get(f"{api_base}/api/assets/{aid}/content") as ok1:
b1 = await ok1.read()
assert ok1.status == 200 and b1 == data
# Remove the last file -> download must 404
p2.unlink()
async with http.get(f"{api_base}/api/assets/{aid}/content") as r2:
body = await r2.json()
assert r2.status == 404
assert body["error"]["code"] == "FILE_NOT_FOUND"

View File

@ -0,0 +1,337 @@
import asyncio
import uuid
import aiohttp
import pytest
@pytest.mark.asyncio
async def test_list_assets_paging_and_sort(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes):
names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"]
for n in names:
await asset_factory(
n,
["models", "checkpoints", "unit-tests", "paging"],
{"epoch": 1},
make_asset_bytes(n, size=2048),
)
# name ascending for stable order
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"},
) as r1:
b1 = await r1.json()
assert r1.status == 200
got1 = [a["name"] for a in b1["assets"]]
assert got1 == sorted(names)[:2]
assert b1["has_more"] is True
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"},
) as r2:
b2 = await r2.json()
assert r2.status == 200
got2 = [a["name"] for a in b2["assets"]]
assert got2 == sorted(names)[2:]
assert b2["has_more"] is False
@pytest.mark.asyncio
async def test_list_assets_include_exclude_and_name_contains(http: aiohttp.ClientSession, api_base: str, asset_factory):
a = await asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
b = await asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"},
) as r:
body = await r.json()
assert r.status == 200
names = [x["name"] for x in body["assets"]]
assert a["name"] in names
assert b["name"] not in names
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests", "name_contains": "inc_"},
) as r2:
body2 = await r2.json()
assert r2.status == 200
names2 = [x["name"] for x in body2["assets"]]
assert a["name"] in names2
assert b["name"] in names2
async with http.get(
api_base + "/api/assets",
params={"include_tags": "non-existing-tag"},
) as r2:
body3 = await r2.json()
assert r2.status == 200
assert not body3["assets"]
@pytest.mark.asyncio
async def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-size"]
n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
await asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
await asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
await asset_factory(n3, t, {}, make_asset_bytes(n3, 3072))
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"},
) as r1:
b1 = await r1.json()
names = [a["name"] for a in b1["assets"]]
assert names[:3] == [n1, n2, n3]
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"},
) as r2:
b2 = await r2.json()
names2 = [a["name"] for a in b2["assets"]]
assert names2[:3] == [n3, n2, n1]
@pytest.mark.asyncio
async def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-upd"]
a1 = await asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
a2 = await asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
# Rename the second asset to bump updated_at
async with http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}) as rp:
upd = await rp.json()
assert rp.status == 200, upd
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"},
) as r:
body = await r.json()
assert r.status == 200
names = [x["name"] for x in body["assets"]]
assert names[0] == "upd_b_renamed.safetensors"
assert a1["name"] in names
@pytest.mark.asyncio
async def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-access"]
await asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
await asyncio.sleep(0.02)
a2 = await asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
# Touch last_access_time of b by downloading its content
await asyncio.sleep(0.02)
async with http.get(f"{api_base}/api/assets/{a2['id']}/content") as dl:
assert dl.status == 200
await dl.read()
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"},
) as r:
body = await r.json()
assert r.status == 200
names = [x["name"] for x in body["assets"]]
assert names[0] == a2["name"]
@pytest.mark.asyncio
async def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-include"]
a = await asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
await asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
# CSV + case-insensitive
async with http.get(
api_base + "/api/assets",
params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
) as r1:
b1 = await r1.json()
assert r1.status == 200
names1 = [x["name"] for x in b1["assets"]]
assert a["name"] in names1
assert not any("beta" in x for x in names1)
# Repeated query params for include_tags
params_multi = [
("include_tags", "unit-tests"),
("include_tags", "lf-include"),
("include_tags", "alpha"),
]
async with http.get(api_base + "/api/assets", params=params_multi) as r2:
b2 = await r2.json()
assert r2.status == 200
names2 = [x["name"] for x in b2["assets"]]
assert a["name"] in names2
assert not any("beta" in x for x in names2)
# Duplicates and spaces in CSV
async with http.get(
api_base + "/api/assets",
params={"include_tags": " unit-tests , lf-include , alpha , alpha "},
) as r3:
b3 = await r3.json()
assert r3.status == 200
names3 = [x["name"] for x in b3["assets"]]
assert a["name"] in names3
@pytest.mark.asyncio
async def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
a = await asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
await asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
# Exclude uppercase should work
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
) as r1:
b1 = await r1.json()
assert r1.status == 200
names1 = [x["name"] for x in b1["assets"]]
assert a["name"] in names1
# Repeated excludes with duplicates
params_multi = [
("include_tags", "unit-tests"),
("include_tags", "lf-exclude"),
("exclude_tags", "beta"),
("exclude_tags", "beta"),
]
async with http.get(api_base + "/api/assets", params=params_multi) as r2:
b2 = await r2.json()
assert r2.status == 200
names2 = [x["name"] for x in b2["assets"]]
assert all("beta" not in x for x in names2)
@pytest.mark.asyncio
async def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-name"]
a1 = await asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
a2 = await asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"},
) as r1:
b1 = await r1.json()
assert r1.status == 200
names1 = [x["name"] for x in b1["assets"]]
assert a1["name"] in names1
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"},
) as r2:
b2 = await r2.json()
assert r2.status == 200
names2 = [x["name"] for x in b2["assets"]]
assert a1["name"] in names2
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"},
) as r3:
b3 = await r3.json()
assert r3.status == 200
names3 = [x["name"] for x in b3["assets"]]
assert a2["name"] in names3
@pytest.mark.asyncio
async def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
await asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
await asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
await asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
# Offset far beyond total
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"},
) as r1:
b1 = await r1.json()
assert r1.status == 200
assert not b1["assets"]
assert b1["has_more"] is False
# Boundary large limit (<=500 is valid)
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"},
) as r2:
b2 = await r2.json()
assert r2.status == 200
assert len(b2["assets"]) == 3
assert b2["has_more"] is False
@pytest.mark.asyncio
async def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
async with http.get(api_base + "/api/assets", params={"offset": "-1"}) as r1:
b1 = await r1.json()
assert r1.status == 400
assert b1["error"]["code"] == "INVALID_QUERY"
async with http.get(api_base + "/api/assets", params={"limit": "abc"}) as r2:
b2 = await r2.json()
assert r2.status == 400
assert b2["error"]["code"] == "INVALID_QUERY"
@pytest.mark.asyncio
async def test_list_assets_invalid_query_rejected(http: aiohttp.ClientSession, api_base: str):
# limit too small
async with http.get(api_base + "/api/assets", params={"limit": "0"}) as r1:
b1 = await r1.json()
assert r1.status == 400
assert b1["error"]["code"] == "INVALID_QUERY"
# bad metadata JSON
async with http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}) as r2:
b2 = await r2.json()
assert r2.status == 400
assert b2["error"]["code"] == "INVALID_QUERY"
@pytest.mark.asyncio
async def test_list_assets_name_contains_literal_underscore(
http,
api_base,
asset_factory,
make_asset_bytes,
):
"""'name_contains' must treat '_' literally, not as a SQL wildcard.
We create:
- foo_bar.safetensors (should match)
- fooxbar.safetensors (must NOT match if '_' is escaped)
- foobar.safetensors (must NOT match)
"""
scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
tags = ["models", "checkpoints", "unit-tests", scope]
a = await asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
b = await asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))
c = await asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700))
async with http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"},
) as r:
body = await r.json()
assert r.status == 200, body
names = [x["name"] for x in body["assets"]]
assert a["name"] in names, f"Expected literal underscore match to include {a['name']}"
assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'"
assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'"
assert body["total"] == 1
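The literal-underscore test above presumes the server escapes SQL LIKE wildcards before building the name_contains pattern. A minimal sketch of that idea using SQLAlchemy's escape parameter; this is an illustration, not the exact implementation from this PR, and column("name") stands in for the real AssetInfo.name ORM attribute.

from sqlalchemy import String, column

def name_contains_clause(term: str):
    """Case-insensitive LIKE clause in which '%' and '_' from the search term match literally."""
    escaped = term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    # column("name") is a stand-in for the AssetInfo.name column used by the real query.
    return column("name", String).ilike(f"%{escaped}%", escape="\\")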

View File

@ -0,0 +1,387 @@
import json
import aiohttp
import pytest
@pytest.mark.asyncio
async def test_meta_and_across_keys_and_types(
http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes
):
name = "mf_and_mix.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-and"]
meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
await asset_factory(name, tags, meta, make_asset_bytes(name, 4096))
# All keys must match (AND semantics)
f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
async with http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,mf-and",
"metadata_filter": json.dumps(f_ok),
},
) as r1:
b1 = await r1.json()
assert r1.status == 200
names = [a["name"] for a in b1["assets"]]
assert name in names
# One key mismatched -> no result
f_bad = {"purpose": "mix", "epoch": 2, "active": True}
async with http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,mf-and",
"metadata_filter": json.dumps(f_bad),
},
) as r2:
b2 = await r2.json()
assert r2.status == 200
assert not b2["assets"]
@pytest.mark.asyncio
async def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
name = "mf_types.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-types"]
meta = {"epoch": 1, "active": True}
await asset_factory(name, tags, meta, make_asset_bytes(name))
# int filter matches numeric
async with http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,mf-types",
"metadata_filter": json.dumps({"epoch": 1}),
},
) as r1:
b1 = await r1.json()
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
# string "1" must NOT match numeric 1
async with http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,mf-types",
"metadata_filter": json.dumps({"epoch": "1"}),
},
) as r2:
b2 = await r2.json()
assert r2.status == 200 and not b2["assets"]
# bool True matches, string "true" must NOT match
async with http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,mf-types",
"metadata_filter": json.dumps({"active": True}),
},
) as r3:
b3 = await r3.json()
assert r3.status == 200 and any(a["name"] == name for a in b3["assets"])
async with http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,mf-types",
"metadata_filter": json.dumps({"active": "true"}),
},
) as r4:
b4 = await r4.json()
assert r4.status == 200 and not b4["assets"]
@pytest.mark.asyncio
async def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
name = "mf_list_scalars.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-list"]
meta = {"flags": ["red", "green"]}
await asset_factory(name, tags, meta, make_asset_bytes(name, 3000))
# Any-of should match because "green" is present
filt_ok = {"flags": ["blue", "green"]}
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)},
) as r1:
b1 = await r1.json()
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
# None of provided flags present -> no match
filt_miss = {"flags": ["blue", "yellow"]}
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)},
) as r2:
b2 = await r2.json()
assert r2.status == 200 and not b2["assets"]
# Duplicates in list should not break matching
filt_dup = {"flags": ["green", "green", "green"]}
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)},
) as r3:
b3 = await r3.json()
assert r3.status == 200 and any(a["name"] == name for a in b3["assets"])
@pytest.mark.asyncio
async def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
http, api_base, asset_factory, make_asset_bytes
):
# a1: key missing; a2: explicit null; a3: concrete value
t = ["models", "checkpoints", "unit-tests", "mf-none"]
a1 = await asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
a2 = await asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
a3 = await asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))
# Filter {maybe: None} must match a1 and a2, not a3
filt = {"maybe": None}
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"},
) as r1:
b1 = await r1.json()
assert r1.status == 200
got = [a["name"] for a in b1["assets"]]
assert a1["name"] in got and a2["name"] in got and a3["name"] not in got
# Any-of with None should include missing/null plus value matches
filt_any = {"maybe": [None, "x"]}
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"},
) as r2:
b2 = await r2.json()
assert r2.status == 200
got2 = [a["name"] for a in b2["assets"]]
assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2
@pytest.mark.asyncio
async def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
name = "mf_nested_json.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-nested"]
cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
await asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))
# Exact JSON object equality (same structure)
async with http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,mf-nested",
"metadata_filter": json.dumps({"config": cfg}),
},
) as r1:
b1 = await r1.json()
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
# Different JSON object should not match
async with http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,mf-nested",
"metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}),
},
) as r2:
b2 = await r2.json()
assert r2.status == 200 and not b2["assets"]
@pytest.mark.asyncio
async def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
name = "mf_list_objects.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-objlist"]
transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
await asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))
# Any-of for list of objects should match when one element equals the filter object
async with http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,mf-objlist",
"metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}),
},
) as r1:
b1 = await r1.json()
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
# Non-matching object -> no match
async with http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,mf-objlist",
"metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}),
},
) as r2:
b2 = await r2.json()
assert r2.status == 200 and not b2["assets"]
@pytest.mark.asyncio
async def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
name = "mf_keys_unicode.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-keys"]
meta = {
"weird.key": "v1",
"path/like": 7,
"with:colon": True,
"ключ": "значение",
"emoji": "🐍",
}
await asset_factory(name, tags, meta, make_asset_bytes(name, 1500))
# Match all the special keys
filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"}
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)},
) as r1:
b1 = await r1.json()
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
# Unicode key match
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})},
) as r2:
b2 = await r2.json()
assert r2.status == 200 and any(a["name"] == name for a in b2["assets"])
@pytest.mark.asyncio
async def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"]
a0 = await asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
a1 = await asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))
# count == 0 must match only a0
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})},
) as r1:
b1 = await r1.json()
assert r1.status == 200
names1 = [a["name"] for a in b1["assets"]]
assert a0["name"] in names1 and a1["name"] not in names1
# Scalar True must match the second asset, whose 'choices' list contains True
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})},
) as r2:
b2 = await r2.json()
assert r2.status == 200 and any(a["name"] == a1["name"] for a in b2["assets"])
@pytest.mark.asyncio
async def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
name = "mf_mixed_list.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-mixed"]
meta = {"mix": ["1", 1, True, None]}
await asset_factory(name, tags, meta, make_asset_bytes(name, 1999))
# Should match because 1 is present
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})},
) as r1:
b1 = await r1.json()
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
# Should NOT match for False
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})},
) as r2:
b2 = await r2.json()
assert r2.status == 200 and not b2["assets"]
@pytest.mark.asyncio
async def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
# Use a unique scope tag to avoid interference
t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"]
x = await asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
y = await asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))
# Filtering by unknown key with None should return both (missing key OR null)
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})},
) as r1:
b1 = await r1.json()
assert r1.status == 200
names = {a["name"] for a in b1["assets"]}
assert x["name"] in names and y["name"] in names
# Filtering by unknown key with concrete value should return none
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})},
) as r2:
b2 = await r2.json()
assert r2.status == 200 and not b2["assets"]
@pytest.mark.asyncio
async def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes):
# alpha matches epoch=1; beta has epoch=2
a = await asset_factory(
"mf_tag_alpha.safetensors",
["models", "checkpoints", "unit-tests", "mf-tag", "alpha"],
{"epoch": 1},
make_asset_bytes("alpha"),
)
b = await asset_factory(
"mf_tag_beta.safetensors",
["models", "checkpoints", "unit-tests", "mf-tag", "beta"],
{"epoch": 2},
make_asset_bytes("beta"),
)
params = {
"include_tags": "unit-tests,mf-tag,alpha",
"exclude_tags": "beta",
"name_contains": "mf_tag_",
"metadata_filter": json.dumps({"epoch": 1}),
}
async with http.get(api_base + "/api/assets", params=params) as r:
body = await r.json()
assert r.status == 200
names = [x["name"] for x in body["assets"]]
assert a["name"] in names
assert b["name"] not in names
@pytest.mark.asyncio
async def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
# Three assets in same scope with different sizes and a common filter key
t = ["models", "checkpoints", "unit-tests", "mf-sort"]
n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
await asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
await asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))
await asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072))
# Sort by size ascending with paging
q = {
"include_tags": "unit-tests,mf-sort",
"metadata_filter": json.dumps({"group": "g"}),
"sort": "size", "order": "asc", "limit": "2",
}
async with http.get(api_base + "/api/assets", params=q) as r1:
b1 = await r1.json()
assert r1.status == 200
got1 = [a["name"] for a in b1["assets"]]
assert got1 == [n1, n2]
assert b1["has_more"] is True
q2 = {**q, "offset": "2"}
async with http.get(api_base + "/api/assets", params=q2) as r2:
b2 = await r2.json()
assert r2.status == 200
got2 = [a["name"] for a in b2["assets"]]
assert got2 == [n3]
assert b2["has_more"] is False

510
tests-assets/test_scans.py Normal file

@@ -0,0 +1,510 @@
import asyncio
import os
import uuid
from pathlib import Path
import aiohttp
import pytest
from conftest import get_asset_filename, trigger_sync_seed_assets
def _base_for(root: str, comfy_tmp_base_dir: Path) -> Path:
assert root in ("input", "output")
return comfy_tmp_base_dir / root
def _mkbytes(label: str, size: int) -> bytes:
seed = sum(label.encode("utf-8")) % 251
return bytes((i * 31 + seed) % 256 for i in range(size))
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_scan_schedule_idempotent_while_running(
root: str,
http,
api_base: str,
comfy_tmp_base_dir: Path,
run_scan_and_wait,
):
"""Idempotent schedule while running."""
scope = f"idem-{uuid.uuid4().hex[:6]}"
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
base.mkdir(parents=True, exist_ok=True)
# Create several seed files (non-zero) to ensure the scan runs long enough
for i in range(8):
(base / f"f{i}.bin").write_bytes(_mkbytes(f"{scope}-{i}", 2 * 1024 * 1024)) # ~2 MiB each
# Seed -> states with hash=NULL
await trigger_sync_seed_assets(http, api_base)
# Schedule once
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r1:
b1 = await r1.json()
assert r1.status == 202, b1
scans1 = {s["root"]: s for s in b1.get("scans", [])}
s1 = scans1.get(root)
assert s1 and s1["status"] in {"scheduled", "running"}
sid1 = s1["scan_id"]
# Schedule again immediately — must return the same scan entry (no new worker)
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r2:
b2 = await r2.json()
assert r2.status == 202, b2
scans2 = {s["root"]: s for s in b2.get("scans", [])}
s2 = scans2.get(root)
assert s2 and s2["scan_id"] == sid1
# Filtered GET must show exactly one scan for this root
async with http.get(api_base + "/api/assets/scan", params={"root": root}) as gs:
bs = await gs.json()
assert gs.status == 200, bs
scans = bs.get("scans", [])
assert len(scans) == 1 and scans[0]["scan_id"] == sid1
# Let it finish to avoid cross-test interference
await run_scan_and_wait(root)
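# Illustrative sketch (an assumption, not part of the suite): roughly what a
# run_scan_and_wait-style helper could look like, polling GET /api/assets/scan
# for the given root until its status leaves "scheduled"/"running". The exact
# terminal status names used by the server are an assumption here.
async def _wait_scan_finished(http, api_base: str, root: str, poll_s: float = 0.2, timeout_s: float = 60.0):
    import time

    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        async with http.get(api_base + "/api/assets/scan", params={"root": root}) as r:
            body = await r.json()
            assert r.status == 200, body
        scans = body.get("scans", [])
        if not scans or scans[0]["status"] not in {"scheduled", "running"}:
            return scans[0] if scans else None
        await asyncio.sleep(poll_s)
    raise TimeoutError(f"scan for root {root!r} did not finish within {timeout_s}s")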
@pytest.mark.asyncio
async def test_scan_status_filter_by_root_and_file_errors(
http,
api_base: str,
comfy_tmp_base_dir: Path,
run_scan_and_wait,
asset_factory,
):
"""Filtering get scan status by root (schedule for both input and output) + file_errors presence."""
# Create one hashed asset in input under a dir we will chmod to 000 to force a PermissionError in the reconcile stage
in_scope = f"filter-in-{uuid.uuid4().hex[:6]}"
protected_dir = _base_for("input", comfy_tmp_base_dir) / "unit-tests" / in_scope / "deny"
protected_dir.mkdir(parents=True, exist_ok=True)
name_in = "protected.bin"
data = b"A" * 4096
await asset_factory(name_in, ["input", "unit-tests", in_scope, "deny"], {}, data)
try:
os.chmod(protected_dir, 0o000)
# Also schedule a scan for output root (no errors there)
out_scope = f"filter-out-{uuid.uuid4().hex[:6]}"
out_dir = _base_for("output", comfy_tmp_base_dir) / "unit-tests" / out_scope
out_dir.mkdir(parents=True, exist_ok=True)
(out_dir / "ok.bin").write_bytes(b"B" * 1024)
await trigger_sync_seed_assets(http, api_base) # seed output file
# Schedule both roots
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": ["input"]}) as r_in:
assert r_in.status == 202
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": ["output"]}) as r_out:
assert r_out.status == 202
# Wait for both to complete, input last (we want its errors)
await run_scan_and_wait("output")
await run_scan_and_wait("input")
# Filter by root=input: only input scan listed and must have file_errors
async with http.get(api_base + "/api/assets/scan", params={"root": "input"}) as gs:
body = await gs.json()
assert gs.status == 200, body
scans = body.get("scans", [])
assert len(scans) == 1
errs = scans[0].get("file_errors", [])
# Must contain at least one error with a message
assert errs and any(e.get("message") for e in errs)
finally:
# Restore perms so cleanup can remove files/dirs
try:
os.chmod(protected_dir, 0o755)
except Exception:
pass
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
@pytest.mark.skipif(os.name == "nt", reason="Permission-based file_errors are unreliable on Windows")
async def test_scan_records_file_errors_permission_denied(
root: str,
http,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
run_scan_and_wait,
):
"""file_errors recording (permission denied) for input/output"""
scope = f"errs-{uuid.uuid4().hex[:6]}"
deny_dir = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "deny"
deny_dir.mkdir(parents=True, exist_ok=True)
name = "deny.bin"
a1 = await asset_factory(name, [root, "unit-tests", scope, "deny"], {}, b"X" * 2048)
asset_filename = get_asset_filename(a1["asset_hash"], ".bin")
try:
os.chmod(deny_dir, 0o000)
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r:
assert r.status == 202
await run_scan_and_wait(root)
async with http.get(api_base + "/api/assets/scan", params={"root": root}) as gs:
body = await gs.json()
assert gs.status == 200, body
scans = body.get("scans", [])
assert len(scans) == 1
errs = scans[0].get("file_errors", [])
# Should contain at least one PermissionError-like record
assert errs
assert any(e.get("path", "").endswith(asset_filename) and e.get("message") for e in errs)
finally:
try:
os.chmod(deny_dir, 0o755)
except Exception:
pass
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_missing_tag_created_and_visible_in_tags(
root: str,
http,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
):
"""Missing tag appears in tags list and increments count (input/output)"""
# Baseline count of 'missing' tag (may be absent)
async with http.get(api_base + "/api/tags", params={"limit": "1000"}) as r0:
t0 = await r0.json()
assert r0.status == 200, t0
byname = {t["name"]: t for t in t0.get("tags", [])}
old_count = int(byname.get("missing", {}).get("count", 0))
scope = f"miss-{uuid.uuid4().hex[:6]}"
name = "missing_me.bin"
created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Y" * 4096)
# Remove the only file and trigger fast pass
p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(created["asset_hash"], ".bin")
assert p.exists()
p.unlink()
await trigger_sync_seed_assets(http, api_base)
# Asset has 'missing' tag
async with http.get(f"{api_base}/api/assets/{created['id']}") as g1:
d1 = await g1.json()
assert g1.status == 200, d1
assert "missing" in set(d1.get("tags", []))
# Tag list now contains 'missing' with increased count
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1:
t1 = await r1.json()
assert r1.status == 200, t1
byname1 = {t["name"]: t for t in t1.get("tags", [])}
assert "missing" in byname1
assert int(byname1["missing"]["count"]) >= old_count + 1
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_missing_reapplies_after_manual_removal(
root: str,
http,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
):
"""Manual removal of 'missing' does not block automatic re-apply (input/output)"""
scope = f"reapply-{uuid.uuid4().hex[:6]}"
name = "reapply.bin"
created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Z" * 1024)
# Make it missing
p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(created["asset_hash"], ".bin")
p.unlink()
await trigger_sync_seed_assets(http, api_base)
# Remove the 'missing' tag manually
async with http.delete(f"{api_base}/api/assets/{created['id']}/tags", json={"tags": ["missing"]}) as rdel:
b = await rdel.json()
assert rdel.status == 200, b
assert "missing" in set(b.get("removed", []))
# Next sync must re-add it
await trigger_sync_seed_assets(http, api_base)
async with http.get(f"{api_base}/api/assets/{created['id']}") as g2:
d2 = await g2.json()
assert g2.status == 200, d2
assert "missing" in set(d2.get("tags", []))
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_delete_one_asset_info_of_missing_asset_keeps_identity(
root: str,
http,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
):
"""Delete one AssetInfo of a missing asset while another exists (input/output)"""
scope = f"twoinfos-{uuid.uuid4().hex[:6]}"
name = "twoinfos.bin"
a1 = await asset_factory(name, [root, "unit-tests", scope], {}, b"W" * 2048)
# Second AssetInfo for the same content under same root (different name to avoid collision)
a2 = await asset_factory("copy_" + name, [root, "unit-tests", scope], {}, b"W" * 2048)
# Remove the file of the first (both point to the same Asset, and we know the on-disk path for a1)
p1 = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(a1["asset_hash"], ".bin")
p1.unlink()
await trigger_sync_seed_assets(http, api_base)
# Both infos should be marked missing
async with http.get(f"{api_base}/api/assets/{a1['id']}") as g1:
d1 = await g1.json()
assert "missing" in set(d1.get("tags", []))
async with http.get(f"{api_base}/api/assets/{a2['id']}") as g2:
d2 = await g2.json()
assert "missing" in set(d2.get("tags", []))
# Delete one info
async with http.delete(f"{api_base}/api/assets/{a1['id']}") as rd:
assert rd.status == 204
# Asset identity still exists (by hash)
h = a1["asset_hash"]
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh:
assert rh.status == 200
# Remaining info still reflects 'missing'
async with http.get(f"{api_base}/api/assets/{a2['id']}") as g3:
d3 = await g3.json()
assert g3.status == 200 and "missing" in set(d3.get("tags", []))
@pytest.mark.asyncio
@pytest.mark.parametrize("keep_root", ["input", "output"])
async def test_delete_last_asset_info_false_keeps_asset_and_states_multiroot(
keep_root: str,
http,
api_base: str,
comfy_tmp_base_dir: Path,
make_asset_bytes,
asset_factory,
):
"""Delete last AssetInfo with delete_content_if_orphan=false keeps asset and the underlying on-disk content."""
other_root = "output" if keep_root == "input" else "input"
scope = f"delfalse-{uuid.uuid4().hex[:6]}"
data = make_asset_bytes(scope, 3072)
# First upload creates the physical file
a1 = await asset_factory("keep1.bin", [keep_root, "unit-tests", scope], {}, data)
# Second upload (other root) is deduped to the same content; no new file on disk
a2 = await asset_factory("keep2.bin", [other_root, "unit-tests", scope], {}, data)
h = a1["asset_hash"]
p1 = _base_for(keep_root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(h, ".bin")
# De-dup semantics: only the first physical file exists
assert p1.exists(), "Expected the first physical file to exist"
# Delete both AssetInfos; keep content on the very last delete
async with http.delete(f"{api_base}/api/assets/{a2['id']}") as rfirst:
assert rfirst.status == 204
async with http.delete(f"{api_base}/api/assets/{a1['id']}?delete_content=false") as rlast:
assert rlast.status == 204
# Asset identity remains and physical content is still present
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh:
assert rh.status == 200
assert p1.exists(), "Content file should remain after keep-content delete"
# Cleanup: re-create a reference by hash and then delete to purge content
payload = {
"hash": h,
"name": "cleanup.bin",
"tags": [keep_root, "unit-tests", scope, "cleanup"],
"user_metadata": {},
}
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as rfh:
ref = await rfh.json()
assert rfh.status == 201, ref
cid = ref["id"]
async with http.delete(f"{api_base}/api/assets/{cid}") as rdel:
assert rdel.status == 204
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_sync_seed_ignores_zero_byte_files(
root: str,
http,
api_base: str,
comfy_tmp_base_dir: Path,
):
scope = f"zero-{uuid.uuid4().hex[:6]}"
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
base.mkdir(parents=True, exist_ok=True)
z = base / "empty.dat"
z.write_bytes(b"") # zero bytes
await trigger_sync_seed_assets(http, api_base)
# No AssetInfo created for this zero-byte file
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests," + scope, "name_contains": "empty.dat"},
) as r:
body = await r.json()
assert r.status == 200, body
assert not [a for a in body.get("assets", []) if a.get("name") == "empty.dat"]
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_sync_seed_idempotency(
root: str,
http,
api_base: str,
comfy_tmp_base_dir: Path,
):
scope = f"idemseed-{uuid.uuid4().hex[:6]}"
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
base.mkdir(parents=True, exist_ok=True)
files = [f"f{i}.dat" for i in range(3)]
for i, n in enumerate(files):
(base / n).write_bytes(_mkbytes(n, 1500 + i * 10))
await trigger_sync_seed_assets(http, api_base)
async with http.get(api_base + "/api/assets", params={"include_tags": "unit-tests," + scope}) as r1:
b1 = await r1.json()
assert r1.status == 200, b1
c1 = len(b1.get("assets", []))
# Seed again -> count must stay the same
await trigger_sync_seed_assets(http, api_base)
async with http.get(api_base + "/api/assets", params={"include_tags": "unit-tests," + scope}) as r2:
b2 = await r2.json()
assert r2.status == 200, b2
c2 = len(b2.get("assets", []))
assert c1 == c2 == len(files)
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_sync_seed_nested_dirs_produce_parent_tags(
root: str,
http,
api_base: str,
comfy_tmp_base_dir: Path,
):
scope = f"nest-{uuid.uuid4().hex[:6]}"
# nested: unit-tests / scope / a / b / c / deep.txt
deep_dir = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "a" / "b" / "c"
deep_dir.mkdir(parents=True, exist_ok=True)
(deep_dir / "deep.txt").write_bytes(scope.encode())
await trigger_sync_seed_assets(http, api_base)
async with http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "name_contains": "deep.txt"},
) as r:
body = await r.json()
assert r.status == 200, body
assets = body.get("assets", [])
assert assets, "seeded asset not found"
tags = set(assets[0].get("tags", []))
# Must include all parent parts as tags + the root
for must in {root, "unit-tests", scope, "a", "b", "c"}:
assert must in tags
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_concurrent_seed_hashing_same_file_no_dupes(
root: str,
http: aiohttp.ClientSession,
api_base: str,
comfy_tmp_base_dir: Path,
run_scan_and_wait,
):
"""
Create a single seed file, then schedule two scans back-to-back.
Expect: no duplicate AssetInfos, a single hashed asset, and no scan failure.
"""
scope = f"conc-seed-{uuid.uuid4().hex[:6]}"
name = "seed_concurrent.bin"
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
base.mkdir(parents=True, exist_ok=True)
(base / name).write_bytes(b"Z" * 2048)
await trigger_sync_seed_assets(http, api_base)
s1, s2 = await asyncio.gather(
http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}),
http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}),
)
await s1.read()
await s2.read()
assert s1.status in (200, 202)
assert s2.status in (200, 202)
await run_scan_and_wait(root)
async with http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
) as r:
b = await r.json()
assert r.status == 200, b
matches = [a for a in b.get("assets", []) if a.get("name") == name]
assert len(matches) == 1
assert matches[0].get("asset_hash"), "Seed should have been hashed into an Asset"
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_cache_state_retarget_on_content_change_asset_info_stays(
root: str,
http,
api_base: str,
comfy_tmp_base_dir: Path,
asset_factory,
make_asset_bytes,
run_scan_and_wait,
):
"""
Start with hashed H1 (AssetInfo A1). Replace file bytes on disk to become H2.
After scan: AssetCacheState points to H2; A1 still references H1; downloading A1 -> 404.
"""
scope = f"retarget-{uuid.uuid4().hex[:6]}"
name = "content_change.bin"
d1 = make_asset_bytes("v1-" + scope, 2048)
a1 = await asset_factory(name, [root, "unit-tests", scope], {}, d1)
aid = a1["id"]
h1 = a1["asset_hash"]
p = comfy_tmp_base_dir / root / "unit-tests" / scope / get_asset_filename(a1["asset_hash"], ".bin")
assert p.exists()
# Change the bytes in place to force a new content hash (H2)
d2 = make_asset_bytes("v2-" + scope, 3072)
p.write_bytes(d2)
# Scan to verify and retarget the state; reconcilers run after scan
await run_scan_and_wait(root)
# AssetInfo still on the old content identity (H1)
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
g = await rg.json()
assert rg.status == 200, g
assert g.get("asset_hash") == h1
# Download must fail until a state exists for H1 (we removed the only one by retarget)
async with http.get(f"{api_base}/api/assets/{aid}/content") as dl:
body = await dl.json()
assert dl.status == 404, body
assert body["error"]["code"] == "FILE_NOT_FOUND"

228
tests-assets/test_tags.py Normal file

@@ -0,0 +1,228 @@
import json
import uuid
import aiohttp
import pytest
@pytest.mark.asyncio
async def test_tags_present(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
# Include zero-usage tags by default
async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
body1 = await r1.json()
assert r1.status == 200
names = [t["name"] for t in body1["tags"]]
# A few system tags from migration should exist:
assert "models" in names
assert "checkpoints" in names
# With include_zero=false, only tags that are actually in use are returned
async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
body2 = await r2.json()
assert r2.status == 200
# We already seeded one asset via fixture, so used tags must be non-empty
used_names = [t["name"] for t in body2["tags"]]
assert "models" in used_names
assert "checkpoints" in used_names
# Prefix filter should refine the list
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}) as r3:
b3 = await r3.json()
assert r3.status == 200
names3 = [t["name"] for t in b3["tags"]]
assert "unit-tests" in names3
assert "models" not in names3 # filtered out by prefix
# Ordering by name ascending should return the names in sorted order
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}) as r4:
b4 = await r4.json()
assert r4.status == 200
names4 = [t["name"] for t in b4["tags"]]
assert names4 == sorted(names4)
@pytest.mark.asyncio
async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes):
# Baseline: system tags exist when include_zero (default) is true
async with http.get(api_base + "/api/tags", params={"limit": "500"}) as r1:
body1 = await r1.json()
assert r1.status == 200
names = [t["name"] for t in body1["tags"]]
assert "models" in names and "checkpoints" in names
# Create a short-lived asset under input with a unique custom tag
scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
custom_tag = f"temp-{uuid.uuid4().hex[:8]}"
name = "tag_seed.bin"
_asset = await asset_factory(
name,
["input", "unit-tests", scope, custom_tag],
{},
make_asset_bytes(name, 512),
)
# While the asset exists, the custom tag must appear when include_zero=false
async with http.get(
api_base + "/api/tags",
params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
) as r2:
body2 = await r2.json()
assert r2.status == 200
used_names = [t["name"] for t in body2["tags"]]
assert custom_tag in used_names
# Delete the asset so the tag usage drops to zero
async with http.delete(f"{api_base}/api/assets/{_asset['id']}") as rd:
assert rd.status == 204
# Now the custom tag must not be returned when include_zero=false
async with http.get(
api_base + "/api/tags",
params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
) as r3:
body3 = await r3.json()
assert r3.status == 200
names_after = [t["name"] for t in body3["tags"]]
assert custom_tag not in names_after
assert not names_after # filtered view should be empty now
@pytest.mark.asyncio
async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]
# Add tags with duplicates and mixed case
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1:
b1 = await r1.json()
assert r1.status == 200, b1
# normalized, deduplicated; 'unit-tests' was already present from the seed
assert set(b1["added"]) == {"newtag", "beta"}
assert set(b1["already_present"]) == {"unit-tests"}
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
g = await rg.json()
assert rg.status == 200
tags_now = set(g["tags"])
assert {"newtag", "beta"}.issubset(tags_now)
# Remove a tag and a non-existent tag
payload_del = {"tags": ["newtag", "does-not-exist"]}
async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2:
b2 = await r2.json()
assert r2.status == 200
assert set(b2["removed"]) == {"newtag"}
assert set(b2["not_present"]) == {"does-not-exist"}
# Verify remaining tags after deletion
async with http.get(f"{api_base}/api/assets/{aid}") as rg2:
g2 = await rg2.json()
assert rg2.status == 200
tags_later = set(g2["tags"])
assert "newtag" not in tags_later
assert "beta" in tags_later # still present
@pytest.mark.asyncio
async def test_tags_list_order_and_prefix(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]
h = seeded_asset["asset_hash"]
# Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1)
async with http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}) as r_add:
add_body = await r_add.json()
assert r_add.status == 200, add_body
# Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'.
payload = {
"hash": h,
"name": "order_only_bbb.safetensors",
"tags": ["input", "unit-tests", "orderbbb"],
"user_metadata": {},
}
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r_copy:
copy_body = await r_copy.json()
assert r_copy.status == 201, copy_body
# 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa'
# because it has higher usage (2 vs 1).
async with http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}) as r1:
b1 = await r1.json()
assert r1.status == 200, b1
names1 = [t["name"] for t in b1["tags"]]
counts1 = {t["name"]: t["count"] for t in b1["tags"]}
# Both must be present within the prefix subset
assert "orderaaa" in names1 and "orderbbb" in names1
# Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1
assert counts1["orderbbb"] >= counts1["orderaaa"]
# And with count_desc, 'orderbbb' appears earlier than 'orderaaa'
assert names1.index("orderbbb") < names1.index("orderaaa")
# 2) name_asc: lexical order should flip the relative order
async with http.get(
api_base + "/api/tags",
params={"prefix": "order", "include_zero": "false", "order": "name_asc"},
) as r2:
b2 = await r2.json()
assert r2.status == 200, b2
names2 = [t["name"] for t in b2["tags"]]
assert "orderaaa" in names2 and "orderbbb" in names2
assert names2.index("orderaaa") < names2.index("orderbbb")
# 3) invalid limit rejected (existing negative case retained)
async with http.get(api_base + "/api/tags", params={"limit": "1001"}) as r3:
b3 = await r3.json()
assert r3.status == 400
assert b3["error"]["code"] == "INVALID_QUERY"
@pytest.mark.asyncio
async def test_tags_endpoints_invalid_bodies(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]
# Add with empty list
async with http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}) as r1:
b1 = await r1.json()
assert r1.status == 400
assert b1["error"]["code"] == "INVALID_BODY"
# Remove with wrong type
async with http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}) as r2:
b2 = await r2.json()
assert r2.status == 400
assert b2["error"]["code"] == "INVALID_BODY"
# metadata_filter provided as JSON array should be rejected (must be object)
async with http.get(
api_base + "/api/assets",
params={"metadata_filter": json.dumps([{"x": 1}])},
) as r3:
b3 = await r3.json()
assert r3.status == 400
assert b3["error"]["code"] == "INVALID_QUERY"
@pytest.mark.asyncio
async def test_tags_prefix_treats_underscore_literal(
http,
api_base,
asset_factory,
make_asset_bytes,
):
"""'prefix' for /api/tags must treat '_' literally, not as a wildcard."""
base = f"pref_{uuid.uuid4().hex[:6]}"
tag_ok = f"{base}_ok" # should match prefix=f"{base}_"
tag_bad = f"{base}xok" # must NOT match if '_' is escaped
scope = f"tags-underscore-{uuid.uuid4().hex[:6]}"
await asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512))
await asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512))
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}) as r:
body = await r.json()
assert r.status == 200, body
names = [t["name"] for t in body["tags"]]
assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard"
assert body["total"] == 1


@@ -0,0 +1,325 @@
import asyncio
import json
import uuid
import aiohttp
import pytest
@pytest.mark.asyncio
async def test_upload_ok_duplicate_reference(http: aiohttp.ClientSession, api_base: str, make_asset_bytes):
name = "dup_a.safetensors"
tags = ["models", "checkpoints", "unit-tests", "alpha"]
meta = {"purpose": "dup"}
data = make_asset_bytes(name)
form1 = aiohttp.FormData()
form1.add_field("file", data, filename=name, content_type="application/octet-stream")
form1.add_field("tags", json.dumps(tags))
form1.add_field("name", name)
form1.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form1) as r1:
a1 = await r1.json()
assert r1.status == 201, a1
assert a1["created_new"] is True
# Second upload with the same data and name should return created_new == False and the same asset
form2 = aiohttp.FormData()
form2.add_field("file", data, filename=name, content_type="application/octet-stream")
form2.add_field("tags", json.dumps(tags))
form2.add_field("name", name)
form2.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form2) as r2:
a2 = await r2.json()
assert r2.status == 200, a2
assert a2["created_new"] is False
assert a2["asset_hash"] == a1["asset_hash"]
assert a2["id"] == a1["id"] # old reference
# Third upload with the same data but a new name should return created_new == False and a new AssetInfo
form3 = aiohttp.FormData()
form3.add_field("file", data, filename=name, content_type="application/octet-stream")
form3.add_field("tags", json.dumps(tags))
form3.add_field("name", name + "_d")
form3.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form3) as r3:
a3 = await r3.json()
assert r3.status == 200, a3
assert a3["created_new"] is False
assert a3["asset_hash"] == a1["asset_hash"]
assert a3["id"] != a1["id"]  # a new reference to the same content
@pytest.mark.asyncio
async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSession, api_base: str):
# Seed a small file first
name = "fastpath_seed.safetensors"
tags = ["models", "checkpoints", "unit-tests"]
meta = {}
form1 = aiohttp.FormData()
form1.add_field("file", b"B" * 1024, filename=name, content_type="application/octet-stream")
form1.add_field("tags", json.dumps(tags))
form1.add_field("name", name)
form1.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form1) as r1:
b1 = await r1.json()
assert r1.status == 201, b1
h = b1["asset_hash"]
# Now POST /api/assets with only hash and no file
form2 = aiohttp.FormData(default_to_multipart=True)
form2.add_field("hash", h)
form2.add_field("tags", json.dumps(tags))
form2.add_field("name", "fastpath_copy.safetensors")
form2.add_field("user_metadata", json.dumps({"purpose": "copy"}))
async with http.post(api_base + "/api/assets", data=form2) as r2:
b2 = await r2.json()
assert r2.status == 200, b2 # fast path returns 200 with created_new == False
assert b2["created_new"] is False
assert b2["asset_hash"] == h
@pytest.mark.asyncio
async def test_upload_fastpath_with_known_hash_and_file(
http: aiohttp.ClientSession, api_base: str
):
# Seed
form1 = aiohttp.FormData()
form1.add_field("file", b"C" * 128, filename="seed.safetensors", content_type="application/octet-stream")
form1.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "fp"]))
form1.add_field("name", "seed.safetensors")
form1.add_field("user_metadata", json.dumps({}))
async with http.post(api_base + "/api/assets", data=form1) as r1:
b1 = await r1.json()
assert r1.status == 201, b1
h = b1["asset_hash"]
# Send both file and hash of existing content -> server must drain file and create from hash (200)
form2 = aiohttp.FormData()
form2.add_field("file", b"ignored" * 10, filename="ignored.bin", content_type="application/octet-stream")
form2.add_field("hash", h)
form2.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "fp"]))
form2.add_field("name", "copy_from_hash.safetensors")
form2.add_field("user_metadata", json.dumps({}))
async with http.post(api_base + "/api/assets", data=form2) as r2:
b2 = await r2.json()
assert r2.status == 200, b2
assert b2["created_new"] is False
assert b2["asset_hash"] == h
@pytest.mark.asyncio
async def test_upload_multiple_tags_fields_are_merged(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"B" * 256, filename="merge.safetensors", content_type="application/octet-stream")
form.add_field("tags", "models,checkpoints") # CSV
form.add_field("tags", json.dumps(["unit-tests", "alpha"])) # JSON array in second field
form.add_field("name", "merge.safetensors")
form.add_field("user_metadata", json.dumps({"u": 1}))
async with http.post(api_base + "/api/assets", data=form) as r1:
created = await r1.json()
assert r1.status in (200, 201), created
aid = created["id"]
# Verify all tags are present on the resource
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
detail = await rg.json()
assert rg.status == 200, detail
tags = set(detail["tags"])
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_concurrent_upload_identical_bytes_different_names(
root: str,
http: aiohttp.ClientSession,
api_base: str,
make_asset_bytes,
):
"""
Two concurrent uploads of identical bytes but different names.
Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True.
"""
scope = f"concupload-{uuid.uuid4().hex[:6]}"
name1, name2 = "cu_a.bin", "cu_b.bin"
data = make_asset_bytes("concurrent", 4096)
tags = [root, "unit-tests", scope]
def _form(name: str) -> aiohttp.FormData:
f = aiohttp.FormData()
f.add_field("file", data, filename=name, content_type="application/octet-stream")
f.add_field("tags", json.dumps(tags))
f.add_field("name", name)
f.add_field("user_metadata", json.dumps({}))
return f
r1, r2 = await asyncio.gather(
http.post(api_base + "/api/assets", data=_form(name1)),
http.post(api_base + "/api/assets", data=_form(name2)),
)
b1, b2 = await r1.json(), await r2.json()
assert r1.status in (200, 201), b1
assert r2.status in (200, 201), b2
assert b1["asset_hash"] == b2["asset_hash"]
assert b1["id"] != b2["id"]
created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])
assert created_flags == [False, True]
async with http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "sort": "name"},
) as rl:
bl = await rl.json()
assert rl.status == 200, bl
names = [a["name"] for a in bl.get("assets", [])]
assert set([name1, name2]).issubset(names)
@pytest.mark.asyncio
async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str):
payload = {
"hash": "blake3:" + "0" * 64,
"name": "nonexistent.bin",
"tags": ["models", "checkpoints", "unit-tests"],
}
async with http.post(api_base + "/api/assets/from-hash", json=payload) as r:
body = await r.json()
assert r.status == 404
assert body["error"]["code"] == "ASSET_NOT_FOUND"
@pytest.mark.asyncio
async def test_upload_zero_byte_rejected(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"", filename="empty.safetensors", content_type="application/octet-stream")
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "edge"]))
form.add_field("name", "empty.safetensors")
form.add_field("user_metadata", json.dumps({}))
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "EMPTY_UPLOAD"
@pytest.mark.asyncio
async def test_upload_invalid_root_tag_rejected(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"A" * 64, filename="badroot.bin", content_type="application/octet-stream")
form.add_field("tags", json.dumps(["not-a-root", "whatever"]))
form.add_field("name", "badroot.bin")
form.add_field("user_metadata", json.dumps({}))
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "INVALID_BODY"
@pytest.mark.asyncio
async def test_upload_user_metadata_must_be_json(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"A" * 128, filename="badmeta.bin", content_type="application/octet-stream")
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "edge"]))
form.add_field("name", "badmeta.bin")
form.add_field("user_metadata", "{not json}") # invalid
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "INVALID_BODY"
@pytest.mark.asyncio
async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str):
async with http.post(api_base + "/api/assets", json={"foo": "bar"}) as r:
body = await r.json()
assert r.status == 415
assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
@pytest.mark.asyncio
async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData(default_to_multipart=True)
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"]))
form.add_field("name", "x.safetensors")
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "MISSING_FILE"
@pytest.mark.asyncio
async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream")
form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"]))
form.add_field("name", "m.safetensors")
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "INVALID_BODY"
assert body["error"]["message"].startswith("unknown models category")
@pytest.mark.asyncio
async def test_upload_models_requires_category(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"A" * 64, filename="nocat.safetensors", content_type="application/octet-stream")
form.add_field("tags", json.dumps(["models"])) # missing category
form.add_field("name", "nocat.safetensors")
form.add_field("user_metadata", json.dumps({}))
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "INVALID_BODY"
@pytest.mark.asyncio
async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream")
# '..' should be rejected by destination resolver
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]))
form.add_field("name", "evil.safetensors")
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
@pytest.mark.asyncio
@pytest.mark.parametrize("root", ["input", "output"])
async def test_duplicate_upload_same_display_name_does_not_clobber(
root: str,
http,
api_base: str,
asset_factory,
make_asset_bytes,
):
"""
Two uploads use the same tags and the same display name but different bytes.
With hash-based filenames, they must NOT overwrite each other. Both assets
remain accessible and serve their original content.
"""
scope = f"dup-path-{uuid.uuid4().hex[:6]}"
display_name = "same_display.bin"
d1 = make_asset_bytes(scope + "-v1", 1536)
d2 = make_asset_bytes(scope + "-v2", 2048)
tags = [root, "unit-tests", scope]
first = await asset_factory(display_name, tags, {}, d1)
second = await asset_factory(display_name, tags, {}, d2)
assert first["id"] != second["id"]
assert first["asset_hash"] != second["asset_hash"] # different content
assert first["name"] == second["name"] == display_name
# Both must be independently retrievable
async with http.get(f"{api_base}/api/assets/{first['id']}/content") as r1:
b1 = await r1.read()
assert r1.status == 200
assert b1 == d1
async with http.get(f"{api_base}/api/assets/{second['id']}/content") as r2:
b2 = await r2.read()
assert r2.status == 200
assert b2 == d2
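# Illustrative sketch (an assumption, not part of the suite): a helper in the
# spirit of conftest.get_asset_filename, deriving the on-disk name from the
# content hash plus the original extension. How much of the hash is kept and
# whether the algorithm prefix (e.g. "blake3:") is stripped is assumed here.
def _hashed_filename(asset_hash: str, ext: str) -> str:
    digest = asset_hash.split(":", 1)[-1]  # drop a possible algorithm prefix
    return digest + ext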


@@ -1,62 +0,0 @@
import pytest
import base64
import json
import struct
from io import BytesIO
from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def model_manager():
return ModelFileManager()
@pytest.fixture
def app(model_manager):
app = web.Application()
routes = web.RouteTableDef()
model_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
img = Image.new('RGB', (100, 100), 'white')
img_byte_arr = BytesIO()
img.save(img_byte_arr, format='PNG')
img_byte_arr.seek(0)
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
safetensors_file = tmp_path / "test_model.safetensors"
header_bytes = json.dumps({
"__metadata__": {
"ssmd_cover_images": json.dumps([img_b64])
}
}).encode('utf-8')
length_bytes = struct.pack('<Q', len(header_bytes))
with open(safetensors_file, 'wb') as f:
f.write(length_bytes)
f.write(header_bytes)
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')
# Verify response
assert response.status == 200
assert response.content_type == 'image/webp'
# Verify the response contains valid image data
img_bytes = BytesIO(await response.read())
img = Image.open(img_bytes)
assert img.format
assert img.format.lower() == 'webp'
# Clean up
img.close()