diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 40871e644..fc3cde9c9 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -65,14 +65,36 @@ jobs: arch: amd64 preset: 'CUDA 12' install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' cuda-version: '12.8' flags: '' + runner_dir: 'cuda_v12' + - os: windows + arch: amd64 + preset: 'CUDA 13' + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' + - '"crt"' + - '"nvvm"' + - '"nvptxcompiler"' + cuda-version: '13.0' + flags: '' + runner_dir: 'cuda_v13' - os: windows arch: amd64 preset: 'ROCm 6' install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe rocm-version: '6.2' flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' + runner_dir: '' runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release env: @@ -96,7 +118,7 @@ jobs: $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" - $subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | Foreach-Object {"${_}_${{ matrix.cuda-version }}"} + $subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"} Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait } @@ -138,7 +160,7 @@ jobs: run: | Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' - cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} + cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}" cmake --build --parallel --preset "${{ matrix.preset }}" cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8 env: @@ -232,7 +254,7 @@ jobs: case "$COMPONENT" in bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; - lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; + lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;; lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;; lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;; diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 4d8cf773c..e470540a2 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -46,7 +46,7 @@ jobs: include: - preset: CPU - preset: CUDA - container: nvidia/cuda:12.8.1-devel-ubuntu22.04 + container: nvidia/cuda:13.0.0-devel-ubuntu22.04 flags: '-DCMAKE_CUDA_ARCHITECTURES=87' - preset: ROCm container: rocm/dev-ubuntu-22.04:6.1.2 @@ -78,8 +78,17 @@ jobs: include: - preset: CPU - preset: CUDA - install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe flags: '-DCMAKE_CUDA_ARCHITECTURES=80' + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' + - '"crt"' + - '"nvvm"' + - '"nvptxcompiler"' + cuda-version: '13.0' - preset: ROCm install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' @@ -102,7 +111,8 @@ jobs: $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" - Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait + $subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"} + Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait } $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path diff --git a/.gitignore b/.gitignore index 3a2af0bd1..eabf94c28 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ dist build .cache +.gocache *.exe .idea test_data diff --git a/CMakeLists.txt b/CMakeLists.txt index d62c8f99f..a03afc6aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64") endif() set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama) -set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama) +set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR}) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR}) @@ -81,7 +81,7 @@ if(CMAKE_CUDA_COMPILER) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda) install(TARGETS ggml-cuda RUNTIME_DEPENDENCIES - DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR} + DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} PRE_INCLUDE_REGEXES cublas cublasLt cudart PRE_EXCLUDE_REGEXES ".*" RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA @@ -98,14 +98,17 @@ check_language(HIP) if(CMAKE_HIP_COMPILER) set(HIP_PLATFORM "amd") - find_package(hip REQUIRED) if(NOT AMDGPU_TARGETS) + find_package(hip REQUIRED) list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$") - elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX) + endif() + + if(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX) list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX}) endif() if(AMDGPU_TARGETS) + find_package(hip REQUIRED) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip) if (WIN32) @@ -114,7 +117,6 @@ if(CMAKE_HIP_COMPILER) target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM) - set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm) install(TARGETS ggml-hip RUNTIME_DEPENDENCY_SET rocm RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP @@ -125,13 +127,13 @@ if(CMAKE_HIP_COMPILER) PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf PRE_EXCLUDE_REGEXES ".*" POST_EXCLUDE_REGEXES "system32" - RUNTIME DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP - LIBRARY DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP + RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP + LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP ) foreach(HIP_LIB_BIN_INSTALL_DIR IN ITEMS ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}) if(EXISTS ${HIP_LIB_BIN_INSTALL_DIR}/rocblas) - install(DIRECTORY ${HIP_LIB_BIN_INSTALL_DIR}/rocblas DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP) + install(DIRECTORY ${HIP_LIB_BIN_INSTALL_DIR}/rocblas DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP) break() endif() endforeach() diff --git a/CMakePresets.json b/CMakePresets.json index ab2cfe9d6..51190c719 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -18,6 +18,14 @@ "name": "CUDA", "inherits": [ "Default" ] }, + { + "name": "CUDA 11", + "inherits": [ "CUDA" ], + "cacheVariables": { + "CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual", + "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2" + } + }, { "name": "CUDA 12", "inherits": [ "CUDA" ], @@ -26,6 +34,14 @@ "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2" } }, + { + "name": "CUDA 13", + "inherits": [ "CUDA" ], + "cacheVariables": { + "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual", + "CMAKE_CUDA_FLAGS": "-t 2" + } + }, { "name": "JetPack 5", "inherits": [ "CUDA" ], @@ -72,11 +88,21 @@ "configurePreset": "CUDA", "targets": [ "ggml-cuda" ] }, + { + "name": "CUDA 11", + "inherits": [ "CUDA" ], + "configurePreset": "CUDA 11" + }, { "name": "CUDA 12", "inherits": [ "CUDA" ], "configurePreset": "CUDA 12" }, + { + "name": "CUDA 13", + "inherits": [ "CUDA" ], + "configurePreset": "CUDA 13" + }, { "name": "JetPack 5", "inherits": [ "CUDA" ], diff --git a/Dockerfile b/Dockerfile index 0dc3c1267..e3ab29af1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,7 @@ # vim: filetype=dockerfile ARG FLAVOR=${TARGETARCH} +ARG PARALLEL=8 ARG ROCMVERSION=6.3.3 ARG JETPACK5VERSION=r35.4.1 @@ -34,26 +35,51 @@ ENV LDFLAGS=-s FROM base AS cpu RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'CPU' \ - && cmake --build --parallel --preset 'CPU' \ - && cmake --install build --component CPU --strip --parallel 8 + && cmake --build --parallel ${PARALLEL} --preset 'CPU' \ + && cmake --install build --component CPU --strip --parallel ${PARALLEL} + +FROM base AS cuda-11 +ARG CUDA11VERSION=11.8 +RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} +ENV PATH=/usr/local/cuda-11/bin:$PATH +ARG PARALLEL +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \ + && cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \ + && cmake --install build --component CUDA --strip --parallel ${PARALLEL} FROM base AS cuda-12 ARG CUDA12VERSION=12.8 RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} ENV PATH=/usr/local/cuda-12/bin:$PATH +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'CUDA 12' \ - && cmake --build --parallel --preset 'CUDA 12' \ - && cmake --install build --component CUDA --strip --parallel 8 + cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\ + && cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \ + && cmake --install build --component CUDA --strip --parallel ${PARALLEL} + + +FROM base AS cuda-13 +ARG CUDA13VERSION=13.0 +RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} +ENV PATH=/usr/local/cuda-13/bin:$PATH +ARG PARALLEL +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \ + && cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \ + && cmake --install build --component CUDA --strip --parallel ${PARALLEL} + FROM base AS rocm-6 ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'ROCm 6' \ - && cmake --build --parallel --preset 'ROCm 6' \ - && cmake --install build --component HIP --strip --parallel 8 + cmake --preset 'ROCm 6' -DOLLAMA_RUNNER_DIR="rocm" \ + && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \ + && cmake --install build --component HIP --strip --parallel ${PARALLEL} FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5 ARG CMAKEVERSION @@ -61,10 +87,11 @@ RUN apt-get update && apt-get install -y curl ccache \ && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 COPY CMakeLists.txt CMakePresets.json . COPY ml/backend/ggml/ggml ml/backend/ggml/ggml +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'JetPack 5' \ - && cmake --build --parallel --preset 'JetPack 5' \ - && cmake --install build --component CUDA --strip --parallel 8 + cmake --preset 'JetPack 5' -DOLLAMA_RUNNER_DIR="cuda_jetpack5" \ + && cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \ + && cmake --install build --component CUDA --strip --parallel ${PARALLEL} FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6 ARG CMAKEVERSION @@ -72,10 +99,11 @@ RUN apt-get update && apt-get install -y curl ccache \ && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 COPY CMakeLists.txt CMakePresets.json . COPY ml/backend/ggml/ggml ml/backend/ggml/ggml +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'JetPack 6' \ - && cmake --build --parallel --preset 'JetPack 6' \ - && cmake --install build --component CUDA --strip --parallel 8 + cmake --preset 'JetPack 6' -DOLLAMA_RUNNER_DIR="cuda_jetpack6" \ + && cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \ + && cmake --install build --component CUDA --strip --parallel ${PARALLEL} FROM base AS build WORKDIR /go/src/github.com/ollama/ollama @@ -92,12 +120,16 @@ RUN --mount=type=cache,target=/root/.cache/go-build \ go build -trimpath -buildmode=pie -o /bin/ollama . FROM --platform=linux/amd64 scratch AS amd64 -COPY --from=cuda-12 dist/lib/ollama /lib/ollama +# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ +COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ +COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ FROM --platform=linux/arm64 scratch AS arm64 -COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa -COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5 -COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6 +# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ +COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ +COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ +COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/ +COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/ FROM scratch AS rocm COPY --from=rocm-6 dist/lib/ollama /lib/ollama diff --git a/README.md b/README.md index 0680590f5..5962f5b28 100644 --- a/README.md +++ b/README.md @@ -413,6 +413,8 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.) - [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.) - [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models) +- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare) +- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads) ### Cloud @@ -541,6 +543,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama) - [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/)) - [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/)) +- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama) ### Mobile @@ -601,6 +604,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama) - [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies) - [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases) +- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network) ### Supported backends diff --git a/api/client.go b/api/client.go index 7cc2acb3d..0d4c97ba9 100644 --- a/api/client.go +++ b/api/client.go @@ -45,6 +45,12 @@ func checkError(resp *http.Response, body []byte) error { return nil } + if resp.StatusCode == http.StatusUnauthorized { + authError := AuthorizationError{StatusCode: resp.StatusCode} + json.Unmarshal(body, &authError) + return authError + } + apiError := StatusError{StatusCode: resp.StatusCode} err := json.Unmarshal(body, &apiError) @@ -214,7 +220,8 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f scanner.Buffer(scanBuf, maxBufferSize) for scanner.Scan() { var errorResponse struct { - Error string `json:"error,omitempty"` + Error string `json:"error,omitempty"` + SigninURL string `json:"signin_url,omitempty"` } bts := scanner.Bytes() @@ -222,7 +229,13 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f return fmt.Errorf("unmarshal: %w", err) } - if response.StatusCode >= http.StatusBadRequest { + if response.StatusCode == http.StatusUnauthorized { + return AuthorizationError{ + StatusCode: response.StatusCode, + Status: response.Status, + SigninURL: errorResponse.SigninURL, + } + } else if response.StatusCode >= http.StatusBadRequest { return StatusError{ StatusCode: response.StatusCode, Status: response.Status, @@ -428,3 +441,21 @@ func (c *Client) Version(ctx context.Context) (string, error) { return version.Version, nil } + +// Signout will signout a client for a local ollama server. +func (c *Client) Signout(ctx context.Context) error { + return c.do(ctx, http.MethodPost, "/api/signout", nil, nil) +} + +// Disconnect will disconnect an ollama instance from ollama.com. +func (c *Client) Disconnect(ctx context.Context, encodedKey string) error { + return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil) +} + +func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) { + var resp UserResponse + if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/api/types.go b/api/types.go index f8187316e..c726f7dd7 100644 --- a/api/types.go +++ b/api/types.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "github.com/google/uuid" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) @@ -36,6 +38,19 @@ func (e StatusError) Error() string { } } +type AuthorizationError struct { + StatusCode int + Status string + SigninURL string `json:"signin_url"` +} + +func (e AuthorizationError) Error() string { + if e.Status != "" { + return e.Status + } + return "something went wrong, please see the ollama server logs for details" +} + // ImageData represents the raw binary data of an image file. type ImageData []byte @@ -286,16 +301,23 @@ func mapToTypeScriptType(jsonType string) string { } } +type ToolFunctionParameters struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required"` + Properties map[string]ToolProperty `json:"properties"` +} + +func (t *ToolFunctionParameters) String() string { + bts, _ := json.Marshal(t) + return string(bts) +} + type ToolFunction struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters struct { - Type string `json:"type"` - Defs any `json:"$defs,omitempty"` - Items any `json:"items,omitempty"` - Required []string `json:"required"` - Properties map[string]ToolProperty `json:"properties"` - } `json:"parameters"` + Name string `json:"name"` + Description string `json:"description"` + Parameters ToolFunctionParameters `json:"parameters"` } func (t *ToolFunction) String() string { @@ -306,13 +328,29 @@ func (t *ToolFunction) String() string { // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. type ChatResponse struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - Message Message `json:"message"` - DoneReason string `json:"done_reason,omitempty"` + // Model is the model name that generated the response. + Model string `json:"model"` + // RemoteModel is the name of the upstream model that generated the response. + RemoteModel string `json:"remote_model,omitempty"` + + // RemoteHost is the URL of the upstream Ollama host that generated the response. + RemoteHost string `json:"remote_host,omitempty"` + + // CreatedAt is the timestamp of the response. + CreatedAt time.Time `json:"created_at"` + + // Message contains the message or part of a message from the model. + Message Message `json:"message"` + + // Done specifies if the response is complete. Done bool `json:"done"` + // DoneReason is the reason the model stopped generating text. + DoneReason string `json:"done_reason,omitempty"` + + DebugInfo *DebugInfo `json:"_debug_info,omitempty"` + Metrics } @@ -322,13 +360,6 @@ type DebugInfo struct { ImageCount int `json:"image_count,omitempty"` } -// DebugTemplateResponse is returned when _debug_render_only is set to true -type DebugTemplateResponse struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - DebugInfo DebugInfo `json:"_debug_info"` -} - type Metrics struct { TotalDuration time.Duration `json:"total_duration,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"` @@ -382,8 +413,12 @@ type EmbedRequest struct { // this request. KeepAlive *Duration `json:"keep_alive,omitempty"` + // Truncate truncates the input to fit the model's max sequence length. Truncate *bool `json:"truncate,omitempty"` + // Dimensions truncates the output embedding to the specified dimension. + Dimensions int `json:"dimensions,omitempty"` + // Options lists model-specific options. Options map[string]any `json:"options"` } @@ -421,18 +456,47 @@ type EmbeddingResponse struct { // CreateRequest is the request passed to [Client.Create]. type CreateRequest struct { - Model string `json:"model"` - Stream *bool `json:"stream,omitempty"` + // Model is the model name to create. + Model string `json:"model"` + + // Stream specifies whether the response is streaming; it is true by default. + Stream *bool `json:"stream,omitempty"` + + // Quantize is the quantization format for the model; leave blank to not change the quantization level. Quantize string `json:"quantize,omitempty"` - From string `json:"from,omitempty"` - Files map[string]string `json:"files,omitempty"` - Adapters map[string]string `json:"adapters,omitempty"` - Template string `json:"template,omitempty"` - License any `json:"license,omitempty"` - System string `json:"system,omitempty"` - Parameters map[string]any `json:"parameters,omitempty"` - Messages []Message `json:"messages,omitempty"` + // From is the name of the model or file to use as the source. + From string `json:"from,omitempty"` + + // RemoteHost is the URL of the upstream ollama API for the model (if any). + RemoteHost string `json:"remote_host,omitempty"` + + // Files is a map of files include when creating the model. + Files map[string]string `json:"files,omitempty"` + + // Adapters is a map of LoRA adapters to include when creating the model. + Adapters map[string]string `json:"adapters,omitempty"` + + // Template is the template used when constructing a request to the model. + Template string `json:"template,omitempty"` + + // License is a string or list of strings for licenses. + License any `json:"license,omitempty"` + + // System is the system prompt for the model. + System string `json:"system,omitempty"` + + // Parameters is a map of hyper-parameters which are applied to the model. + Parameters map[string]any `json:"parameters,omitempty"` + + // Messages is a list of messages added to the model before chat and generation requests. + Messages []Message `json:"messages,omitempty"` + + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` + + // Info is a map of additional information for the model + Info map[string]any `json:"info,omitempty"` // Deprecated: set the model name with Model instead Name string `json:"name"` @@ -470,8 +534,12 @@ type ShowResponse struct { Parameters string `json:"parameters,omitempty"` Template string `json:"template,omitempty"` System string `json:"system,omitempty"` + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` Details ModelDetails `json:"details,omitempty"` Messages []Message `json:"messages,omitempty"` + RemoteModel string `json:"remote_model,omitempty"` + RemoteHost string `json:"remote_host,omitempty"` ModelInfo map[string]any `json:"model_info,omitempty"` ProjectorInfo map[string]any `json:"projector_info,omitempty"` Tensors []Tensor `json:"tensors,omitempty"` @@ -530,12 +598,14 @@ type ProcessResponse struct { // ListModelResponse is a single model description in [ListResponse]. type ListModelResponse struct { - Name string `json:"name"` - Model string `json:"model"` - ModifiedAt time.Time `json:"modified_at"` - Size int64 `json:"size"` - Digest string `json:"digest"` - Details ModelDetails `json:"details,omitempty"` + Name string `json:"name"` + Model string `json:"model"` + RemoteModel string `json:"remote_model,omitempty"` + RemoteHost string `json:"remote_host,omitempty"` + ModifiedAt time.Time `json:"modified_at"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Details ModelDetails `json:"details,omitempty"` } // ProcessModelResponse is a single model description in [ProcessResponse]. @@ -559,6 +629,12 @@ type GenerateResponse struct { // Model is the model name that generated the response. Model string `json:"model"` + // RemoteModel is the name of the upstream model that generated the response. + RemoteModel string `json:"remote_model,omitempty"` + + // RemoteHost is the URL of the upstream Ollama host that generated the response. + RemoteHost string `json:"remote_host,omitempty"` + // CreatedAt is the timestamp of the response. CreatedAt time.Time `json:"created_at"` @@ -582,6 +658,8 @@ type GenerateResponse struct { Metrics ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + DebugInfo *DebugInfo `json:"_debug_info,omitempty"` } // ModelDetails provides details about a model. @@ -594,6 +672,18 @@ type ModelDetails struct { QuantizationLevel string `json:"quantization_level"` } +// UserResponse provides information about a user. +type UserResponse struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Bio string `json:"bio,omitempty"` + AvatarURL string `json:"avatarurl,omitempty"` + FirstName string `json:"firstname,omitempty"` + LastName string `json:"lastname,omitempty"` + Plan string `json:"plan,omitempty"` +} + // Tensor describes the metadata for a given tensor. type Tensor struct { Name string `json:"name"` @@ -883,7 +973,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { if t < 0 { d.Duration = time.Duration(math.MaxInt64) } else { - d.Duration = time.Duration(int(t) * int(time.Second)) + d.Duration = time.Duration(t * float64(time.Second)) } case string: d.Duration, err = time.ParseDuration(t) diff --git a/api/types_test.go b/api/types_test.go index 841853808..5393b4623 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -17,6 +17,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) { req string exp *Duration }{ + { + name: "Unset", + req: `{ }`, + exp: nil, + }, { name: "Positive Integer", req: `{ "keep_alive": 42 }`, @@ -25,7 +30,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) { { name: "Positive Float", req: `{ "keep_alive": 42.5 }`, - exp: &Duration{42 * time.Second}, + exp: &Duration{42500 * time.Millisecond}, }, { name: "Positive Integer String", @@ -436,3 +441,50 @@ func TestThinking_UnmarshalJSON(t *testing.T) { }) } } + +func TestToolFunctionParameters_String(t *testing.T) { + tests := []struct { + name string + params ToolFunctionParameters + expected string + }{ + { + name: "simple object with string property", + params: ToolFunctionParameters{ + Type: "object", + Required: []string{"name"}, + Properties: map[string]ToolProperty{ + "name": { + Type: PropertyType{"string"}, + Description: "The name of the person", + }, + }, + }, + expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`, + }, + { + name: "marshal failure returns empty string", + params: ToolFunctionParameters{ + Type: "object", + Defs: func() any { + // Create a cycle that will cause json.Marshal to fail + type selfRef struct { + Self *selfRef + } + s := &selfRef{} + s.Self = s + return s + }(), + Properties: map[string]ToolProperty{}, + }, + expected: "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := test.params.String() + assert.Equal(t, test.expected, result) + }) + } +} diff --git a/auth/auth.go b/auth/auth.go index e1d854124..f820964e7 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -18,21 +18,13 @@ import ( const defaultPrivateKey = "id_ed25519" -func keyPath() (string, error) { +func GetPublicKey() (string, error) { home, err := os.UserHomeDir() if err != nil { return "", err } - return filepath.Join(home, ".ollama", defaultPrivateKey), nil -} - -func GetPublicKey() (string, error) { - keyPath, err := keyPath() - if err != nil { - return "", err - } - + keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) privateKeyFile, err := os.ReadFile(keyPath) if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) @@ -59,11 +51,12 @@ func NewNonce(r io.Reader, length int) (string, error) { } func Sign(ctx context.Context, bts []byte) (string, error) { - keyPath, err := keyPath() + home, err := os.UserHomeDir() if err != nil { return "", err } + keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) privateKeyFile, err := os.ReadFile(keyPath) if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) diff --git a/cmd/cmd.go b/cmd/cmd.go index 8fe068655..369a27a48 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -47,6 +47,8 @@ import ( "github.com/ollama/ollama/version" ) +const ConnectInstructions = "To sign in, navigate to:\n %s\n\n" + // ensureThinkingSupport emits a warning if the model does not advertise thinking support func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { if name == "" { @@ -56,10 +58,8 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) if err != nil { return } - for _, cap := range resp.Capabilities { - if cap == model.CapabilityThinking { - return - } + if slices.Contains(resp.Capabilities, model.CapabilityThinking) { + return } fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name) } @@ -288,7 +288,17 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { Think: opts.Think, } - return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil }) + return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error { + if r.RemoteModel != "" && opts.ShowConnect { + p.StopAndClear() + if strings.HasPrefix(r.RemoteHost, "https://ollama.com") { + fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel) + } else { + fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost) + } + } + return nil + }) } func StopHandler(cmd *cobra.Command, args []string) error { @@ -309,9 +319,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { interactive := true opts := runOptions{ - Model: args[0], - WordWrap: os.Getenv("TERM") == "xterm-256color", - Options: map[string]any{}, + Model: args[0], + WordWrap: os.Getenv("TERM") == "xterm-256color", + Options: map[string]any{}, + ShowConnect: true, } format, err := cmd.Flags().GetString("format") @@ -369,6 +380,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { } prompts = append([]string{string(in)}, prompts...) + opts.ShowConnect = false opts.WordWrap = false interactive = false } @@ -435,6 +447,15 @@ func RunHandler(cmd *cobra.Command, args []string) error { if interactive { if err := loadOrUnloadModel(cmd, &opts); err != nil { + var sErr api.AuthorizationError + if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { + fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") + + if sErr.SigninURL != "" { + fmt.Printf(ConnectInstructions, sErr.SigninURL) + } + return nil + } return err } @@ -455,6 +476,59 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generate(cmd, opts) } +func SigninHandler(cmd *cobra.Command, args []string) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + user, err := client.Whoami(cmd.Context()) + if err != nil { + var aErr api.AuthorizationError + if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized { + fmt.Println("You need to be signed in to Ollama to run Cloud models.") + fmt.Println() + + if aErr.SigninURL != "" { + fmt.Printf(ConnectInstructions, aErr.SigninURL) + } + return nil + } + return err + } + + if user != nil && user.Name != "" { + fmt.Printf("You are already signed in as user '%s'\n", user.Name) + fmt.Println() + return nil + } + + return nil +} + +func SignoutHandler(cmd *cobra.Command, args []string) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + err = client.Signout(cmd.Context()) + if err != nil { + var aErr api.AuthorizationError + if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized { + fmt.Println("You are not signed in to ollama.com") + fmt.Println() + return nil + } else { + return err + } + } + + fmt.Println("You have signed out of ollama.com") + fmt.Println() + return nil +} + func PushHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { @@ -466,6 +540,25 @@ func PushHandler(cmd *cobra.Command, args []string) error { return err } + n := model.ParseName(args[0]) + if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") { + _, err := client.Whoami(cmd.Context()) + if err != nil { + var aErr api.AuthorizationError + if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized { + fmt.Println("You need to be signed in to push models to ollama.com.") + fmt.Println() + + if aErr.SigninURL != "" { + fmt.Printf(ConnectInstructions, aErr.SigninURL) + } + return nil + } + + return err + } + } + p := progress.NewProgress(os.Stderr) defer p.Stop() @@ -502,12 +595,12 @@ func PushHandler(cmd *cobra.Command, args []string) error { request := api.PushRequest{Name: args[0], Insecure: insecure} - n := model.ParseName(args[0]) if err := client.Push(cmd.Context(), &request, fn); err != nil { if spinner != nil { spinner.Stop() } - if strings.Contains(err.Error(), "access denied") { + errStr := strings.ToLower(err.Error()) + if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") { return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") } return err @@ -541,7 +634,14 @@ func ListHandler(cmd *cobra.Command, args []string) error { for _, m := range models.Models { if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) { - data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")}) + var size string + if m.RemoteModel != "" { + size = "-" + } else { + size = format.HumanBytes(m.Size) + } + + data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")}) } } @@ -626,8 +726,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { KeepAlive: &api.Duration{Duration: 0}, } if err := loadOrUnloadModel(cmd, opts); err != nil { - if !strings.Contains(err.Error(), "not found") { - return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err) + if !strings.Contains(strings.ToLower(err.Error()), "not found") { + fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0]) } } @@ -738,12 +838,36 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error { } tableRender("Model", func() (rows [][]string) { + if resp.RemoteHost != "" { + rows = append(rows, []string{"", "Remote model", resp.RemoteModel}) + rows = append(rows, []string{"", "Remote URL", resp.RemoteHost}) + } + if resp.ModelInfo != nil { arch := resp.ModelInfo["general.architecture"].(string) rows = append(rows, []string{"", "architecture", arch}) - rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))}) - rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)}) - rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)}) + + var paramStr string + if resp.Details.ParameterSize != "" { + paramStr = resp.Details.ParameterSize + } else if v, ok := resp.ModelInfo["general.parameter_count"]; ok { + if f, ok := v.(float64); ok { + paramStr = format.HumanNumber(uint64(f)) + } + } + rows = append(rows, []string{"", "parameters", paramStr}) + + if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok { + if f, ok := v.(float64); ok { + rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)}) + } + } + + if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok { + if f, ok := v.(float64); ok { + rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)}) + } + } } else { rows = append(rows, []string{"", "architecture", resp.Details.Family}) rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) @@ -991,6 +1115,52 @@ type runOptions struct { KeepAlive *api.Duration Think *api.ThinkValue HideThinking bool + ShowConnect bool +} + +func (r runOptions) Copy() runOptions { + var messages []api.Message + if r.Messages != nil { + messages = make([]api.Message, len(r.Messages)) + copy(messages, r.Messages) + } + + var images []api.ImageData + if r.Images != nil { + images = make([]api.ImageData, len(r.Images)) + copy(images, r.Images) + } + + var opts map[string]any + if r.Options != nil { + opts = make(map[string]any, len(r.Options)) + for k, v := range r.Options { + opts[k] = v + } + } + + var think *api.ThinkValue + if r.Think != nil { + cThink := *r.Think + think = &cThink + } + + return runOptions{ + Model: r.Model, + ParentModel: r.ParentModel, + Prompt: r.Prompt, + Messages: messages, + WordWrap: r.WordWrap, + Format: r.Format, + System: r.System, + Images: images, + Options: opts, + MultiModal: r.MultiModal, + KeepAlive: r.KeepAlive, + Think: think, + HideThinking: r.HideThinking, + ShowConnect: r.ShowConnect, + } } type displayResponseState struct { @@ -1546,6 +1716,22 @@ func NewCLI() *cobra.Command { pushCmd.Flags().Bool("insecure", false, "Use an insecure registry") + signinCmd := &cobra.Command{ + Use: "signin", + Short: "Sign in to ollama.com", + Args: cobra.ExactArgs(0), + PreRunE: checkServerHeartbeat, + RunE: SigninHandler, + } + + signoutCmd := &cobra.Command{ + Use: "signout", + Short: "Sign out from ollama.com", + Args: cobra.ExactArgs(0), + PreRunE: checkServerHeartbeat, + RunE: SignoutHandler, + } + listCmd := &cobra.Command{ Use: "list", Aliases: []string{"ls"}, @@ -1640,6 +1826,8 @@ func NewCLI() *cobra.Command { stopCmd, pullCmd, pushCmd, + signinCmd, + signoutCmd, listCmd, psCmd, copyCmd, diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index cf5fe7caa..a84272c8e 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -3,10 +3,12 @@ package cmd import ( "bytes" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" "os" + "reflect" "strings" "testing" "time" @@ -304,6 +306,8 @@ func TestDeleteHandler(t *testing.T) { w.WriteHeader(http.StatusOK) } else { w.WriteHeader(http.StatusNotFound) + errPayload := `{"error":"model '%s' not found"}` + w.Write([]byte(fmt.Sprintf(errPayload, req.Name))) } return } @@ -346,7 +350,7 @@ func TestDeleteHandler(t *testing.T) { } err := DeleteHandler(cmd, []string{"test-model-not-found"}) - if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") { + if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") { t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err) } } @@ -488,9 +492,35 @@ func TestPushHandler(t *testing.T) { w.(http.Flusher).Flush() } }, + "/api/me": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + }, }, expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n", }, + { + name: "not signed in push", + modelName: "notsignedin-model", + serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ + "/api/me": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + err := json.NewEncoder(w).Encode(map[string]string{ + "error": "unauthorized", + "signin_url": "https://somethingsomething", + }) + if err != nil { + t.Fatal(err) + } + }, + }, + expectedOutput: "You need to be signed in to push", + }, { name: "unauthorized push", modelName: "unauthorized-model", @@ -499,12 +529,17 @@ func TestPushHandler(t *testing.T) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) err := json.NewEncoder(w).Encode(map[string]string{ - "error": "access denied", + "error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}", }) if err != nil { t.Fatal(err) } }, + "/api/me": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + }, }, expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own", }, @@ -522,6 +557,10 @@ func TestPushHandler(t *testing.T) { defer mockServer.Close() t.Setenv("OLLAMA_HOST", mockServer.URL) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + initializeKeypair() cmd := &cobra.Command{} cmd.Flags().Bool("insecure", false, "") @@ -557,7 +596,7 @@ func TestPushHandler(t *testing.T) { t.Errorf("expected no error, got %v", err) } if tt.expectedOutput != "" { - if got := string(stdout); got != tt.expectedOutput { + if got := string(stdout); !strings.Contains(got, tt.expectedOutput) { t.Errorf("expected output %q, got %q", tt.expectedOutput, got) } } @@ -915,3 +954,286 @@ func TestNewCreateRequest(t *testing.T) { }) } } + +func TestRunOptions_Copy(t *testing.T) { + // Setup test data + originalKeepAlive := &api.Duration{Duration: 5 * time.Minute} + originalThink := &api.ThinkValue{Value: "test reasoning"} + + original := runOptions{ + Model: "test-model", + ParentModel: "parent-model", + Prompt: "test prompt", + Messages: []api.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi there"}, + }, + WordWrap: true, + Format: "json", + System: "system prompt", + Images: []api.ImageData{ + []byte("image1"), + []byte("image2"), + }, + Options: map[string]any{ + "temperature": 0.7, + "max_tokens": 1000, + "top_p": 0.9, + }, + MultiModal: true, + KeepAlive: originalKeepAlive, + Think: originalThink, + HideThinking: false, + ShowConnect: true, + } + + // Test the copy + copied := original.Copy() + + // Test 1: Verify the copy is not the same instance + if &copied == &original { + t.Error("Copy should return a different instance") + } + + // Test 2: Verify all fields are copied correctly + tests := []struct { + name string + got interface{} + want interface{} + }{ + {"Model", copied.Model, original.Model}, + {"ParentModel", copied.ParentModel, original.ParentModel}, + {"Prompt", copied.Prompt, original.Prompt}, + {"WordWrap", copied.WordWrap, original.WordWrap}, + {"Format", copied.Format, original.Format}, + {"System", copied.System, original.System}, + {"MultiModal", copied.MultiModal, original.MultiModal}, + {"HideThinking", copied.HideThinking, original.HideThinking}, + {"ShowConnect", copied.ShowConnect, original.ShowConnect}, + } + + for _, tt := range tests { + if !reflect.DeepEqual(tt.got, tt.want) { + t.Errorf("%s mismatch: got %v, want %v", tt.name, tt.got, tt.want) + } + } + + // Test 3: Verify Messages slice is deeply copied + if len(copied.Messages) != len(original.Messages) { + t.Errorf("Messages length mismatch: got %d, want %d", len(copied.Messages), len(original.Messages)) + } + + if len(copied.Messages) > 0 && &copied.Messages[0] == &original.Messages[0] { + t.Error("Messages should be different instances") + } + + // Modify original to verify independence + if len(original.Messages) > 0 { + originalContent := original.Messages[0].Content + original.Messages[0].Content = "modified" + if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" { + t.Error("Messages should be independent after copy") + } + // Restore for other tests + original.Messages[0].Content = originalContent + } + + // Test 4: Verify Images slice is deeply copied + if len(copied.Images) != len(original.Images) { + t.Errorf("Images length mismatch: got %d, want %d", len(copied.Images), len(original.Images)) + } + + if len(copied.Images) > 0 && &copied.Images[0] == &original.Images[0] { + t.Error("Images should be different instances") + } + + // Modify original to verify independence + if len(original.Images) > 0 { + originalImage := original.Images[0] + original.Images[0] = []byte("modified") + if len(copied.Images) > 0 && string(copied.Images[0]) == "modified" { + t.Error("Images should be independent after copy") + } + // Restore for other tests + original.Images[0] = originalImage + } + + // Test 5: Verify Options map is deeply copied + if len(copied.Options) != len(original.Options) { + t.Errorf("Options length mismatch: got %d, want %d", len(copied.Options), len(original.Options)) + } + + if len(copied.Options) > 0 && &copied.Options == &original.Options { + t.Error("Options map should be different instances") + } + + // Modify original to verify independence + if len(original.Options) > 0 { + originalTemp := original.Options["temperature"] + original.Options["temperature"] = 0.9 + if copied.Options["temperature"] == 0.9 { + t.Error("Options should be independent after copy") + } + // Restore for other tests + original.Options["temperature"] = originalTemp + } + + // Test 6: Verify KeepAlive pointer is copied (shallow copy) + if copied.KeepAlive != original.KeepAlive { + t.Error("KeepAlive pointer should be the same (shallow copy)") + } + + // Test 7: Verify Think pointer creates a new instance + if original.Think != nil && copied.Think == original.Think { + t.Error("Think should be a different instance") + } + + if original.Think != nil && copied.Think != nil { + if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) { + t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value) + } + } + + // Test 8: Test with zero values + zeroOriginal := runOptions{} + zeroCopy := zeroOriginal.Copy() + + if !reflect.DeepEqual(zeroCopy, zeroOriginal) { + fmt.Printf("orig: %#v\ncopy: %#v\n", zeroOriginal, zeroCopy) + t.Error("Copy of zero value should equal original zero value") + } +} + +func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) { + // Test with empty slices and maps + original := runOptions{ + Messages: []api.Message{}, + Images: []api.ImageData{}, + Options: map[string]any{}, + } + + copied := original.Copy() + + if copied.Messages == nil { + t.Error("Empty Messages slice should remain empty, not nil") + } + + if copied.Images == nil { + t.Error("Empty Images slice should remain empty, not nil") + } + + if copied.Options == nil { + t.Error("Empty Options map should remain empty, not nil") + } + + if len(copied.Messages) != 0 { + t.Error("Empty Messages slice should remain empty") + } + + if len(copied.Images) != 0 { + t.Error("Empty Images slice should remain empty") + } + + if len(copied.Options) != 0 { + t.Error("Empty Options map should remain empty") + } +} + +func TestRunOptions_Copy_NilPointers(t *testing.T) { + // Test with nil pointers + original := runOptions{ + KeepAlive: nil, + Think: nil, + } + + copied := original.Copy() + + if copied.KeepAlive != nil { + t.Error("Nil KeepAlive should remain nil") + } + + if copied.Think != nil { + t.Error("Nil Think should remain nil") + } +} + +func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) { + tests := []struct { + name string + think *api.ThinkValue + }{ + {"nil Think", nil}, + {"bool true", &api.ThinkValue{Value: true}}, + {"bool false", &api.ThinkValue{Value: false}}, + {"string value", &api.ThinkValue{Value: "reasoning text"}}, + {"int value", &api.ThinkValue{Value: 42}}, + {"nil value", &api.ThinkValue{Value: nil}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + original := runOptions{Think: tt.think} + copied := original.Copy() + + if tt.think == nil { + if copied.Think != nil { + t.Error("Nil Think should remain nil") + } + return + } + + if copied.Think == nil { + t.Error("Non-nil Think should not become nil") + return + } + + if copied.Think == original.Think { + t.Error("Think should be a different instance") + } + + if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) { + t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value) + } + }) + } +} + +func TestRunOptions_Copy_Independence(t *testing.T) { + // Test that modifications to original don't affect copy + originalThink := &api.ThinkValue{Value: "original"} + original := runOptions{ + Model: "original-model", + Messages: []api.Message{{Role: "user", Content: "original"}}, + Options: map[string]any{"key": "value"}, + Think: originalThink, + } + + copied := original.Copy() + + // Modify original + original.Model = "modified-model" + if len(original.Messages) > 0 { + original.Messages[0].Content = "modified" + } + original.Options["key"] = "modified" + if original.Think != nil { + original.Think.Value = "modified" + } + + // Verify copy is unchanged + if copied.Model == "modified-model" { + t.Error("Copy Model should not be affected by original modification") + } + + if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" { + t.Error("Copy Messages should not be affected by original modification") + } + + if copied.Options["key"] == "modified" { + t.Error("Copy Options should not be affected by original modification") + } + + if copied.Think != nil && copied.Think.Value == "modified" { + t.Error("Copy Think should not be affected by original modification") + } +} diff --git a/cmd/interactive.go b/cmd/interactive.go index e290d84ce..cf0aced14 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -195,16 +195,24 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Println("Usage:\n /load ") continue } + origOpts := opts.Copy() + opts.Model = args[1] opts.Messages = []api.Message{} fmt.Printf("Loading model '%s'\n", opts.Model) opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet) if err != nil { + if strings.Contains(err.Error(), "not found") { + fmt.Printf("Couldn't find model '%s'\n", opts.Model) + opts = origOpts.Copy() + continue + } return err } if err := loadOrUnloadModel(cmd, &opts); err != nil { if strings.Contains(err.Error(), "not found") { - fmt.Printf("error: %v\n", err) + fmt.Printf("Couldn't find model '%s'\n", opts.Model) + opts = origOpts.Copy() continue } if strings.Contains(err.Error(), "does not support thinking") { diff --git a/convert/convert_bert.go b/convert/convert_bert.go index a9f4b8a77..6b0d0030a 100644 --- a/convert/convert_bert.go +++ b/convert/convert_bert.go @@ -28,6 +28,7 @@ type bertModel struct { LayerNormEPS float32 `json:"layer_norm_eps"` LayerNormEpsilon float32 `json:"layer_norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"` + normalizeEmbeddings bool PoolingType uint32 } @@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error { var pooling string for _, m := range modules { - if m.Type == "sentence_transformers.models.Pooling" { + switch m.Type { + case "sentence_transformers.models.Pooling": pooling = m.Path - break + case "sentence_transformers.models.Normalize": + p.normalizeEmbeddings = true } } @@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV { kv["general.architecture"] = "bert" kv["bert.attention.causal"] = false kv["bert.pooling_type"] = p.PoolingType + kv["bert.normalize_embeddings"] = p.normalizeEmbeddings kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer) diff --git a/convert/convert_gptoss.go b/convert/convert_gptoss.go index bd362169b..2048b18be 100644 --- a/convert/convert_gptoss.go +++ b/convert/convert_gptoss.go @@ -15,19 +15,24 @@ import ( type gptossModel struct { ModelParameters - HiddenLayers uint32 `json:"num_hidden_layers"` - HiddenSize uint32 `json:"hidden_size"` - IntermediateSize uint32 `json:"intermediate_size"` - AttentionHeads uint32 `json:"num_attention_heads"` - KeyValueHeads uint32 `json:"num_key_value_heads"` - HeadDim uint32 `json:"head_dim"` - Experts uint32 `json:"num_experts"` - ExpertsPerToken uint32 `json:"experts_per_token"` - RMSNormEpsilon float32 `json:"rms_norm_eps"` - InitialContextLength uint32 `json:"initial_context_length"` - RopeTheta float32 `json:"rope_theta"` - RopeScalingFactor float32 `json:"rope_scaling_factor"` - SlidingWindow uint32 `json:"sliding_window"` + HiddenLayers uint32 `json:"num_hidden_layers"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + AttentionHeads uint32 `json:"num_attention_heads"` + KeyValueHeads uint32 `json:"num_key_value_heads"` + HeadDim uint32 `json:"head_dim"` + Experts uint32 `json:"num_experts"` + LocalExperts uint32 `json:"num_local_experts"` + ExpertsPerToken uint32 `json:"experts_per_token"` + RMSNormEpsilon float32 `json:"rms_norm_eps"` + InitialContextLength uint32 `json:"initial_context_length"` + RopeTheta float32 `json:"rope_theta"` + RopeScalingFactor float32 `json:"rope_scaling_factor"` + RopeScaling struct { + Factor float32 `json:"factor"` + } `json:"rope_scaling"` + SlidingWindow uint32 `json:"sliding_window"` } var _ ModelConverter = (*gptossModel)(nil) @@ -36,11 +41,11 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV { kv := m.ModelParameters.KV(t) kv["general.architecture"] = "gptoss" kv["general.file_type"] = uint32(4) - kv["gptoss.context_length"] = uint32(m.RopeScalingFactor * float32(m.InitialContextLength)) + kv["gptoss.context_length"] = cmp.Or(m.MaxPositionEmbeddings, uint32(m.RopeScalingFactor*float32(m.InitialContextLength))) kv["gptoss.block_count"] = m.HiddenLayers kv["gptoss.embedding_length"] = m.HiddenSize kv["gptoss.feed_forward_length"] = m.IntermediateSize - kv["gptoss.expert_count"] = m.Experts + kv["gptoss.expert_count"] = cmp.Or(m.Experts, m.LocalExperts) kv["gptoss.expert_used_count"] = m.ExpertsPerToken kv["gptoss.attention.head_count"] = m.AttentionHeads kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads @@ -49,7 +54,7 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV { kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5) kv["gptoss.attention.sliding_window"] = m.SlidingWindow kv["gptoss.rope.freq_base"] = m.RopeTheta - kv["gptoss.rope.scaling.factor"] = m.RopeScalingFactor + kv["gptoss.rope.scaling.factor"] = cmp.Or(m.RopeScalingFactor, m.RopeScaling.Factor) kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|> kv["tokenizer.ggml.add_bos_token"] = false @@ -92,6 +97,11 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor { for name, mxfp4 := range mxfp4s { dims := mxfp4.blocks.Shape() + + if !strings.HasSuffix(name, ".weight") { + name += ".weight" + } + out = append(out, &ggml.Tensor{ Name: name, Kind: uint32(ggml.TensorTypeMXFP4), @@ -104,25 +114,47 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor { } func (m *gptossModel) Replacements() []string { - return []string{ - // noop replacements so other replacements will not be applied - ".blocks", ".blocks", - ".scales", ".scales", - // real replacements - "block", "blk", - "attn.norm", "attn_norm", - "attn.qkv", "attn_qkv", - "attn.sinks", "attn_sinks", - "attn.out", "attn_out", - "mlp.norm", "ffn_norm", - "mlp.gate", "ffn_gate_inp", - "mlp.mlp1_", "ffn_gate_up_exps.", - "mlp.mlp2_", "ffn_down_exps.", - "embedding", "token_embd", - "norm", "output_norm", - "unembedding", "output", - "scale", "weight", + var replacements []string + if m.MaxPositionEmbeddings > 0 { + // hf flavored model + replacements = []string{ + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.layers", "blk", + "input_layernorm", "attn_norm", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_out", + "self_attn.sinks", "attn_sinks", + "post_attention_layernorm", "ffn_norm", + "mlp.router", "ffn_gate_inp", + "mlp.experts.gate_up_proj_", "ffn_gate_up_exps.", + "mlp.experts.down_proj_", "ffn_down_exps.", + "model.norm", "output_norm", + } + } else { + replacements = []string{ + // noop replacements so other replacements will not be applied + ".blocks", ".blocks", + ".scales", ".scales", + // real replacements + "block", "blk", + "attn.norm", "attn_norm", + "attn.qkv", "attn_qkv", + "attn.sinks", "attn_sinks", + "attn.out", "attn_out", + "mlp.norm", "ffn_norm", + "mlp.gate", "ffn_gate_inp", + "mlp.mlp1_", "ffn_gate_up_exps.", + "mlp.mlp2_", "ffn_down_exps.", + "embedding", "token_embd", + "norm", "output_norm", + "unembedding", "output", + "scale", "weight", + } } + return replacements } type mxfp4 struct { @@ -140,7 +172,20 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) { blocksDims[i] = int(d) } - var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes())) + bts := b.Bytes() + var tmp [16]byte + for i := 0; i < b.Len(); i += 16 { + for j := range 8 { + // transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc + a, b := bts[i+j], bts[i+j+8] + tmp[2*j+0] = (a & 0x0F) | (b << 4) + tmp[2*j+1] = (a >> 4) | (b & 0xF0) + } + + copy(bts[i:i+16], tmp[:]) + } + + var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts)) var s bytes.Buffer if _, err := m.scales.WriteTo(&s); err != nil { @@ -174,5 +219,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) { return 0, err } - return 0, nil + return int64(len(u8s)), nil } diff --git a/convert/reader.go b/convert/reader.go index 907d2a9ef..b3f7a8660 100644 --- a/convert/reader.go +++ b/convert/reader.go @@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 { const ( tensorKindFP32 uint32 = iota tensorKindFP16 - tensorKindMXFP4 = 4 tensorKindBF16 = 30 + tensorKindMXFP4 = 39 ) func (t tensorBase) Kind() uint32 { diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index ccc596732..eea0de2f5 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -96,7 +96,7 @@ type safetensor struct { func (st safetensor) Kind() uint32 { kind := st.tensorBase.Kind() - if st.dtype == "BF16" && kind != tensorKindFP32 { + if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 { kind = tensorKindBF16 } @@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { switch st.Kind() { case tensorKindFP32: - return 0, binary.Write(w, binary.LittleEndian, f32s) + return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s) case tensorKindFP16: f16s := make([]uint16, len(f32s)) for i := range f32s { f16s[i] = float16.Fromfloat32(f32s[i]).Bits() } - return 0, binary.Write(w, binary.LittleEndian, f16s) + return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s) case tensorKindBF16: u8s := bfloat16.EncodeFloat32(f32s) - return 0, binary.Write(w, binary.LittleEndian, u8s) + return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s) default: return 0, fmt.Errorf("unknown storage type: %d", st.Kind()) } diff --git a/convert/reader_test.go b/convert/reader_test.go index 6dbe32a51..c3d094f10 100644 --- a/convert/reader_test.go +++ b/convert/reader_test.go @@ -230,3 +230,65 @@ func TestSafetensors(t *testing.T) { }) } } + +func TestSafetensorKind(t *testing.T) { + tests := []struct { + name string + st safetensor + expected uint32 + }{ + { + name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16", + st: safetensor{ + tensorBase: &tensorBase{ + name: "weight.matrix", + shape: []uint64{10, 10}, // will default to FP16 + }, + dtype: "BF16", + }, + expected: tensorKindBF16, + }, + { + name: "BF16 dtype with v. prefix should return base kind", + st: safetensor{ + tensorBase: &tensorBase{ + name: "v.weight.matrix", + shape: []uint64{10, 10}, // will default to FP16 + }, + dtype: "BF16", + }, + expected: tensorKindFP16, + }, + { + name: "BF16 dtype with FP32 base kind should return FP32", + st: safetensor{ + tensorBase: &tensorBase{ + name: "weight.matrix", + shape: []uint64{10}, // will default to FP32 + }, + dtype: "BF16", + }, + expected: tensorKindFP32, + }, + { + name: "Non-BF16 dtype should return base kind", + st: safetensor{ + tensorBase: &tensorBase{ + name: "weight.matrix", + shape: []uint64{10, 10}, // will default to FP16 + }, + dtype: "FP16", + }, + expected: tensorKindFP16, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.st.Kind() + if result != tt.expected { + t.Errorf("Kind() = %d, expected %d", result, tt.expected) + } + }) + } +} diff --git a/discover/amd_common.go b/discover/amd_common.go deleted file mode 100644 index 08834b22d..000000000 --- a/discover/amd_common.go +++ /dev/null @@ -1,83 +0,0 @@ -//go:build linux || windows - -package discover - -import ( - "errors" - "log/slog" - "os" - "path/filepath" - "runtime" - "strings" -) - -// Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns -func rocmLibUsable(libDir string) bool { - slog.Debug("evaluating potential rocm lib dir " + libDir) - for _, g := range ROCmLibGlobs { - res, _ := filepath.Glob(filepath.Join(libDir, g)) - if len(res) == 0 { - return false - } - } - return true -} - -func GetSupportedGFX(libDir string) ([]string, error) { - var ret []string - files, err := filepath.Glob(filepath.Join(libDir, "rocblas", "library", "TensileLibrary_lazy_gfx*.dat")) - if err != nil { - return nil, err - } - for _, file := range files { - ret = append(ret, strings.TrimSuffix(strings.TrimPrefix(filepath.Base(file), "TensileLibrary_lazy_"), ".dat")) - } - return ret, nil -} - -func commonAMDValidateLibDir() (string, error) { - // Favor our bundled version - - // Installer payload location if we're running the installed binary - rocmTargetDir := filepath.Join(LibOllamaPath, "rocm") - if rocmLibUsable(rocmTargetDir) { - slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir) - return rocmTargetDir, nil - } - - // Prefer explicit HIP env var - hipPath := os.Getenv("HIP_PATH") - if hipPath != "" { - hipLibDir := filepath.Join(hipPath, "bin") - if rocmLibUsable(hipLibDir) { - slog.Debug("detected ROCM via HIP_PATH=" + hipPath) - return hipLibDir, nil - } - } - - // Scan the LD_LIBRARY_PATH or PATH - pathEnv := "LD_LIBRARY_PATH" - if runtime.GOOS == "windows" { - pathEnv = "PATH" - } - - paths := os.Getenv(pathEnv) - for _, path := range filepath.SplitList(paths) { - d, err := filepath.Abs(path) - if err != nil { - continue - } - if rocmLibUsable(d) { - return d, nil - } - } - - // Well known location(s) - for _, path := range RocmStandardLocations { - if rocmLibUsable(path) { - return path, nil - } - } - - return "", errors.New("no suitable rocm found, falling back to CPU") -} diff --git a/discover/amd_hip_windows.go b/discover/amd_hip_windows.go deleted file mode 100644 index bf19ef064..000000000 --- a/discover/amd_hip_windows.go +++ /dev/null @@ -1,147 +0,0 @@ -package discover - -import ( - "errors" - "fmt" - "log/slog" - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - hipSuccess = 0 - hipErrorNoDevice = 100 -) - -type hipDevicePropMinimal struct { - Name [256]byte - unused1 [140]byte - GcnArchName [256]byte // gfx#### - iGPU int // Doesn't seem to actually report correctly - unused2 [128]byte -} - -// Wrap the amdhip64.dll library for GPU discovery -type HipLib struct { - dll windows.Handle - hipGetDeviceCount uintptr - hipGetDeviceProperties uintptr - hipMemGetInfo uintptr - hipSetDevice uintptr - hipDriverGetVersion uintptr -} - -func NewHipLib() (*HipLib, error) { - // At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs - h, err := windows.LoadLibrary("amdhip64_6.dll") - if err != nil { - return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err) - } - hl := &HipLib{} - hl.dll = h - hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount") - if err != nil { - return nil, err - } - hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties") - if err != nil { - return nil, err - } - hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo") - if err != nil { - return nil, err - } - hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice") - if err != nil { - return nil, err - } - hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion") - if err != nil { - return nil, err - } - return hl, nil -} - -// The hip library only evaluates the ROCR_VISIBLE_DEVICES variable at startup -// so we have to unload/reset the library after we do our initial discovery -// to make sure our updates to that variable are processed by llama.cpp -func (hl *HipLib) Release() { - err := windows.FreeLibrary(hl.dll) - if err != nil { - slog.Warn("failed to unload amdhip64.dll", "error", err) - } - hl.dll = 0 -} - -func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) { - if hl.dll == 0 { - return 0, 0, errors.New("dll has been unloaded") - } - var version int - status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version))) - if status != hipSuccess { - return 0, 0, fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err) - } - - slog.Debug("hipDriverGetVersion", "version", version) - driverMajor = version / 10000000 - driverMinor = (version - (driverMajor * 10000000)) / 100000 - - return driverMajor, driverMinor, nil -} - -func (hl *HipLib) HipGetDeviceCount() int { - if hl.dll == 0 { - slog.Error("dll has been unloaded") - return 0 - } - var count int - status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count))) - if status == hipErrorNoDevice { - slog.Info("AMD ROCm reports no devices found") - return 0 - } - if status != hipSuccess { - slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err) - } - return count -} - -func (hl *HipLib) HipSetDevice(device int) error { - if hl.dll == 0 { - return errors.New("dll has been unloaded") - } - status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device)) - if status != hipSuccess { - return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err) - } - return nil -} - -func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) { - if hl.dll == 0 { - return nil, errors.New("dll has been unloaded") - } - var props hipDevicePropMinimal - status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device)) - if status != hipSuccess { - return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err) - } - return &props, nil -} - -// free, total, err -func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) { - if hl.dll == 0 { - return 0, 0, errors.New("dll has been unloaded") - } - var totalMemory uint64 - var freeMemory uint64 - status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory))) - if status != hipSuccess { - return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err) - } - return freeMemory, totalMemory, nil -} diff --git a/discover/amd_linux.go b/discover/amd_linux.go deleted file mode 100644 index ebffbdf66..000000000 --- a/discover/amd_linux.go +++ /dev/null @@ -1,541 +0,0 @@ -package discover - -import ( - "bufio" - "errors" - "fmt" - "io" - "io/fs" - "log/slog" - "os" - "path/filepath" - "regexp" - "slices" - "sort" - "strconv" - "strings" - - "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/format" -) - -// Discovery logic for AMD/ROCm GPUs - -const ( - DriverVersionFile = "/sys/module/amdgpu/version" - AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/" - GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties" - - // Prefix with the node dir - GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line - - // Direct Rendering Manager sysfs location - DRMDeviceDirGlob = "/sys/class/drm/card*/device" - DRMTotalMemoryFile = "mem_info_vram_total" - DRMUsedMemoryFile = "mem_info_vram_used" - - // In hex; properties file is in decimal - DRMUniqueIDFile = "unique_id" - DRMVendorFile = "vendor" - DRMDeviceFile = "device" -) - -var ( - // Used to validate if the given ROCm lib is usable - ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here... - RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"} -) - -// Gather GPU information from the amdgpu driver if any supported GPUs are detected -// Only called once during bootstrap -func AMDGetGPUInfo() ([]RocmGPUInfo, error) { - resp := []RocmGPUInfo{} - if !AMDDetected() { - return resp, fmt.Errorf("AMD GPUs not detected") - } - - // Opportunistic logging of driver version to aid in troubleshooting - driverMajor, driverMinor, err := AMDDriverVersion() - if err != nil { - // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU - slog.Warn("ollama recommends running the https://www.amd.com/en/support/download/linux-drivers.html", "error", err) - } - - // Determine if the user has already pre-selected which GPUs to look at, then ignore the others - var visibleDevices []string - hipVD := envconfig.HipVisibleDevices() // zero based index only - rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID - gpuDO := envconfig.GpuDeviceOrdinal() // zero based index - switch { - case rocrVD != "": - visibleDevices = strings.Split(rocrVD, ",") - case hipVD != "": - visibleDevices = strings.Split(hipVD, ",") - case gpuDO != "": - visibleDevices = strings.Split(gpuDO, ",") - } - - gfxOverride := envconfig.HsaOverrideGfxVersion() - var supported []string - var libDir string - - // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract - // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU) - matches, _ := filepath.Glob(GPUPropertiesFileGlob) - sort.Slice(matches, func(i, j int) bool { - // /sys/class/kfd/kfd/topology/nodes//properties - a, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[i])), 10, 64) - if err != nil { - slog.Debug("parse err", "error", err, "match", matches[i]) - return false - } - b, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[j])), 10, 64) - if err != nil { - slog.Debug("parse err", "error", err, "match", matches[i]) - return false - } - return a < b - }) - gpuCount := 0 - gpuOrdinalID := 0 - for _, match := range matches { - slog.Debug("evaluating amdgpu node " + match) - fp, err := os.Open(match) - if err != nil { - slog.Debug("failed to open sysfs node", "file", match, "error", err) - continue - } - defer fp.Close() - - scanner := bufio.NewScanner(fp) - isCPU := false - var major, minor, patch uint64 - var vendor, device, uniqueID uint64 - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - // Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs - if strings.HasPrefix(line, "gfx_target_version") { - ver := strings.Fields(line) - - // Detect CPUs - if len(ver) == 2 && ver[1] == "0" { - slog.Debug("detected CPU " + match) - isCPU = true - break - } - - if len(ver) != 2 || len(ver[1]) < 5 { - slog.Warn("malformed "+match, "gfx_target_version", line) - // If this winds up being a CPU, our offsets may be wrong - continue - } - l := len(ver[1]) - var err1, err2, err3 error - patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32) - minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32) - major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32) - if err1 != nil || err2 != nil || err3 != nil { - slog.Debug("malformed int " + line) - continue - } - } else if strings.HasPrefix(line, "vendor_id") { - ver := strings.Fields(line) - if len(ver) != 2 { - slog.Debug("malformed", "vendor_id", line) - continue - } - vendor, err = strconv.ParseUint(ver[1], 10, 64) - if err != nil { - slog.Debug("malformed", "vendor_id", line, "error", err) - } - } else if strings.HasPrefix(line, "device_id") { - ver := strings.Fields(line) - if len(ver) != 2 { - slog.Debug("malformed", "device_id", line) - continue - } - device, err = strconv.ParseUint(ver[1], 10, 64) - if err != nil { - slog.Debug("malformed", "device_id", line, "error", err) - } - } else if strings.HasPrefix(line, "unique_id") { - ver := strings.Fields(line) - if len(ver) != 2 { - slog.Debug("malformed", "unique_id", line) - continue - } - uniqueID, err = strconv.ParseUint(ver[1], 10, 64) - if err != nil { - slog.Debug("malformed", "unique_id", line, "error", err) - } - } - // TODO - any other properties we want to extract and record? - // vendor_id + device_id -> pci lookup for "Name" - // Other metrics that may help us understand relative performance between multiple GPUs - } - - // Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers - // into consideration, so we instead map the device over to the DRM driver sysfs nodes which - // do reliably report VRAM usage. - - if isCPU { - continue - } - - // Skip over any GPUs that are masked - if major == 0 && minor == 0 && patch == 0 { - slog.Debug("skipping gpu with gfx000") - continue - } - - // Look up the memory for the current node - totalMemory := uint64(0) - usedMemory := uint64(0) - var usedFile string - mapping := []struct { - id uint64 - filename string - }{ - {vendor, DRMVendorFile}, - {device, DRMDeviceFile}, - {uniqueID, DRMUniqueIDFile}, // Not all devices will report this - } - slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID) - // Map over to DRM location to find the total/free memory - drmMatches, _ := filepath.Glob(DRMDeviceDirGlob) - for _, devDir := range drmMatches { - matched := true - for _, m := range mapping { - if m.id == 0 { - // Null ID means it didn't populate, so we can't use it to match - continue - } - filename := filepath.Join(devDir, m.filename) - buf, err := os.ReadFile(filename) - if err != nil { - slog.Debug("failed to read sysfs node", "file", filename, "error", err) - matched = false - break - } - // values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu - cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64) - if err != nil { - slog.Debug("failed to parse sysfs node", "file", filename, "error", err) - matched = false - break - } - if cmp != m.id { - matched = false - break - } - } - if !matched { - continue - } - - // Found the matching DRM directory - slog.Debug("matched", "amdgpu", match, "drm", devDir) - totalFile := filepath.Join(devDir, DRMTotalMemoryFile) - buf, err := os.ReadFile(totalFile) - if err != nil { - slog.Debug("failed to read sysfs node", "file", totalFile, "error", err) - break - } - totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64) - if err != nil { - slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err) - break - } - - usedFile = filepath.Join(devDir, DRMUsedMemoryFile) - usedMemory, err = getFreeMemory(usedFile) - if err != nil { - slog.Debug("failed to update used memory", "error", err) - } - break - } - - var name string - // TODO - PCI ID lookup - if vendor > 0 && device > 0 { - name = fmt.Sprintf("%04x:%04x", vendor, device) - } - - // Favor UUIDs if available to reduce possibility of getting the numeric IDs wrong - var ID string - if uniqueID != 0 { - ID = fmt.Sprintf("GPU-%016x", uniqueID) - } else { - ID = strconv.Itoa(gpuOrdinalID) - } - - gpuInfo := RocmGPUInfo{ - GpuInfo: GpuInfo{ - Library: "rocm", - memInfo: memInfo{ - TotalMemory: totalMemory, - FreeMemory: (totalMemory - usedMemory), - }, - ID: ID, - Name: name, - Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), - MinimumMemory: rocmMinimumMemory, - DriverMajor: driverMajor, - DriverMinor: driverMinor, - }, - usedFilepath: usedFile, - index: gpuCount, - } - - // Keep track of numeric IDs based on valid GPUs - gpuCount += 1 - - // If the user wants to filter to a subset of devices, filter out if we aren't a match - if len(visibleDevices) > 0 { - include := false - for _, visible := range visibleDevices { - if (uniqueID != 0 && visible == gpuInfo.ID) || visible == strconv.Itoa(gpuInfo.index) { - include = true - break - } - } - if !include { - reason := "filtering out device per user request" - slog.Info(reason, "id", gpuInfo.ID, "index", gpuInfo.index, "visible_devices", visibleDevices) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - - continue - } - } - - // Ordinal IDs are based on the visible GPUs - gpuOrdinalID += 1 - - // iGPU detection, remove this check once we can support an iGPU variant of the rocm library - if totalMemory < IGPUMemLimit { - reason := "unsupported Radeon iGPU detected skipping" - slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - continue - } - minVer, err := strconv.Atoi(RocmComputeMajorMin) - if err != nil { - slog.Error("invalid RocmComputeMajorMin setting", "value", RocmComputeMajorMin, "error", err) - } - if int(major) < minVer { - reason := fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch) - slog.Warn(reason, "gpu", gpuInfo.ID) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - - continue - } - - slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) - slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "available", format.HumanBytes2(totalMemory-usedMemory)) - - // Final validation is gfx compatibility - load the library if we haven't already loaded it - // even if the user overrides, we still need to validate the library - if libDir == "" { - libDir, err = AMDValidateLibDir() - if err != nil { - err = fmt.Errorf("unable to verify rocm library: %w", err) - slog.Warn(err.Error()) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: err.Error(), - }) - return nil, err - } - } - gpuInfo.DependencyPath = []string{libDir} - - if gfxOverride == "" { - // Only load supported list once - if len(supported) == 0 { - supported, err = GetSupportedGFX(libDir) - if err != nil { - err = fmt.Errorf("failed to lookup supported GFX types: %w", err) - slog.Warn(err.Error()) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: err.Error(), - }) - return nil, err - } - slog.Debug("rocm supported GPUs", "types", supported) - } - gfx := gpuInfo.Compute - if !slices.Contains[[]string, string](supported, gfx) { - reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported) - slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - - // TODO - consider discrete markdown just for ROCM troubleshooting? - slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage") - continue - } else { - slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx) - } - } else { - slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride) - } - - // Check for env var workarounds - if name == "1002:687f" { // Vega RX 56 - gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"}) - } - - // The GPU has passed all the verification steps and is supported - resp = append(resp, gpuInfo) - } - if len(resp) == 0 { - err := fmt.Errorf("no compatible amdgpu devices detected") - slog.Info(err.Error()) - return nil, err - } - if err := verifyKFDDriverAccess(); err != nil { - err = fmt.Errorf("amdgpu devices detected but permission problems block access: %w", err) - slog.Error(err.Error()) - return nil, err - } - return resp, nil -} - -// Quick check for AMD driver so we can skip amdgpu discovery if not present -func AMDDetected() bool { - // Some driver versions (older?) don't have a version file, so just lookup the parent dir - sysfsDir := filepath.Dir(DriverVersionFile) - _, err := os.Stat(sysfsDir) - if errors.Is(err, os.ErrNotExist) { - slog.Debug("amdgpu driver not detected " + sysfsDir) - return false - } else if err != nil { - slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err) - return false - } - return true -} - -// Prefer to use host installed ROCm, as long as it meets our minimum requirements -// failing that, tell the user how to download it on their own -func AMDValidateLibDir() (string, error) { - libDir, err := commonAMDValidateLibDir() - if err == nil { - return libDir, nil - } - - // Well known ollama installer path - installedRocmDir := "/usr/share/ollama/lib/rocm" - if rocmLibUsable(installedRocmDir) { - return installedRocmDir, nil - } - - // If we still haven't found a usable rocm, the user will have to install it on their own - slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install") - return "", errors.New("no suitable rocm found, falling back to CPU") -} - -func AMDDriverVersion() (driverMajor, driverMinor int, err error) { - _, err = os.Stat(DriverVersionFile) - if err != nil { - return 0, 0, fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err) - } - fp, err := os.Open(DriverVersionFile) - if err != nil { - return 0, 0, err - } - defer fp.Close() - verString, err := io.ReadAll(fp) - if err != nil { - return 0, 0, err - } - - pattern := `\A(\d+)\.(\d+).*` - regex := regexp.MustCompile(pattern) - match := regex.FindStringSubmatch(string(verString)) - if len(match) < 2 { - return 0, 0, fmt.Errorf("malformed version string %s", string(verString)) - } - driverMajor, err = strconv.Atoi(match[1]) - if err != nil { - return 0, 0, err - } - driverMinor, err = strconv.Atoi(match[2]) - if err != nil { - return 0, 0, err - } - return driverMajor, driverMinor, nil -} - -func (gpus RocmGPUInfoList) RefreshFreeMemory() error { - if len(gpus) == 0 { - return nil - } - for i := range gpus { - usedMemory, err := getFreeMemory(gpus[i].usedFilepath) - if err != nil { - return err - } - slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory)) - gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory - } - return nil -} - -func getFreeMemory(usedFile string) (uint64, error) { - buf, err := os.ReadFile(usedFile) - if err != nil { - return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err) - } - usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64) - if err != nil { - slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err) - return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err) - } - return usedMemory, nil -} - -func verifyKFDDriverAccess() error { - // Verify we have permissions - either running as root, or we have group access to the driver - fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0o666) - if err != nil { - if errors.Is(err, fs.ErrPermission) { - return fmt.Errorf("permissions not set up properly. Either run ollama as root, or add you user account to the render group. %w", err) - } else if errors.Is(err, fs.ErrNotExist) { - // Container runtime failure? - return fmt.Errorf("kfd driver not loaded. If running in a container, remember to include '--device /dev/kfd --device /dev/dri'") - } - return fmt.Errorf("failed to check permission on /dev/kfd: %w", err) - } - fd.Close() - return nil -} - -func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { - ids := []string{} - for _, info := range gpuInfo { - if info.Library != "rocm" { - // TODO shouldn't happen if things are wired correctly... - slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library) - continue - } - ids = append(ids, info.ID) - } - // There are 3 potential env vars to use to select GPUs. - // ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux - // GPU_DEVICE_ORDINAL supports numeric IDs only - // HIP_VISIBLE_DEVICES supports numeric IDs only - return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",") -} diff --git a/discover/amd_windows.go b/discover/amd_windows.go deleted file mode 100644 index 0659d12f8..000000000 --- a/discover/amd_windows.go +++ /dev/null @@ -1,218 +0,0 @@ -package discover - -import ( - "bytes" - "errors" - "fmt" - "log/slog" - "path/filepath" - "slices" - "strconv" - "strings" - - "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/format" -) - -const ( - - // TODO We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true - iGPUName = "AMD Radeon(TM) Graphics" -) - -var ( - // Used to validate if the given ROCm lib is usable - ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6 - RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob? -) - -// Only called once during bootstrap -func AMDGetGPUInfo() ([]RocmGPUInfo, error) { - resp := []RocmGPUInfo{} - hl, err := NewHipLib() - if err != nil { - slog.Debug(err.Error()) - return nil, err - } - defer hl.Release() - - driverMajor, driverMinor, err := hl.AMDDriverVersion() - if err != nil { - // For now this is benign, but we may eventually need to fail compatibility checks - slog.Debug("error looking up amd driver version", "error", err) - } - - // Note: the HIP library automatically handles subsetting to any *_VISIBLE_DEVICES the user specified - count := hl.HipGetDeviceCount() - if count == 0 { - err := fmt.Errorf("no compatible amdgpu devices detected") - slog.Info(err.Error()) - return nil, err - } - - libDir, err := AMDValidateLibDir() - if err != nil { - err = fmt.Errorf("unable to verify rocm library: %w", err) - slog.Warn(err.Error()) - return nil, err - } - - var supported []string - gfxOverride := envconfig.HsaOverrideGfxVersion() - if gfxOverride == "" { - supported, err = GetSupportedGFX(libDir) - if err != nil { - err = fmt.Errorf("failed to lookup supported GFX types: %w", err) - slog.Warn(err.Error()) - return nil, err - } - } else { - slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride) - } - - slog.Debug("detected hip devices", "count", count) - // TODO how to determine the underlying device ID when visible devices is causing this to subset? - for i := range count { - err = hl.HipSetDevice(i) - if err != nil { - slog.Warn("set device", "id", i, "error", err) - continue - } - - props, err := hl.HipGetDeviceProperties(i) - if err != nil { - slog.Warn("get properties", "id", i, "error", err) - continue - } - n := bytes.IndexByte(props.Name[:], 0) - name := string(props.Name[:n]) - // TODO is UUID actually populated on windows? - // Can luid be used on windows for setting visible devices (and is it actually set?) - n = bytes.IndexByte(props.GcnArchName[:], 0) - gfx := string(props.GcnArchName[:n]) - slog.Debug("hip device", "id", i, "name", name, "gfx", gfx) - // slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0 - // TODO Why isn't props.iGPU accurate!? - - freeMemory, totalMemory, err := hl.HipMemGetInfo() - if err != nil { - slog.Warn("get mem info", "id", i, "error", err) - continue - } - - gpuInfo := RocmGPUInfo{ - GpuInfo: GpuInfo{ - Library: "rocm", - memInfo: memInfo{ - TotalMemory: totalMemory, - FreeMemory: freeMemory, - }, - // Free memory reporting on Windows is not reliable until we bump to ROCm v6.2 - UnreliableFreeMemory: true, - - ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices - DependencyPath: []string{libDir}, - MinimumMemory: rocmMinimumMemory, - Name: name, - Compute: gfx, - DriverMajor: driverMajor, - DriverMinor: driverMinor, - }, - index: i, - } - - // iGPU detection, remove this check once we can support an iGPU variant of the rocm library - if strings.EqualFold(name, iGPUName) || totalMemory < IGPUMemLimit { - reason := "unsupported Radeon iGPU detected skipping" - slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - continue - } - - // Strip off Target Features when comparing - if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) { - reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported) - slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - // HSA_OVERRIDE_GFX_VERSION not supported on windows - continue - } else { - slog.Debug("amdgpu is supported", "gpu", i, "gpu_type", gfx) - } - - slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory)) - slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory)) - - resp = append(resp, gpuInfo) - } - - return resp, nil -} - -func AMDValidateLibDir() (string, error) { - libDir, err := commonAMDValidateLibDir() - if err == nil { - return libDir, nil - } - - // Installer payload (if we're running from some other location) - rocmTargetDir := filepath.Join(LibOllamaPath, "rocm") - if rocmLibUsable(rocmTargetDir) { - slog.Debug("detected ollama installed ROCm at " + rocmTargetDir) - return rocmTargetDir, nil - } - - // Should not happen on windows since we include it in the installer, but stand-alone binary might hit this - slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm") - return "", errors.New("no suitable rocm found, falling back to CPU") -} - -func (gpus RocmGPUInfoList) RefreshFreeMemory() error { - if len(gpus) == 0 { - return nil - } - hl, err := NewHipLib() - if err != nil { - slog.Debug(err.Error()) - return err - } - defer hl.Release() - - for i := range gpus { - err := hl.HipSetDevice(gpus[i].index) - if err != nil { - return err - } - freeMemory, _, err := hl.HipMemGetInfo() - if err != nil { - slog.Warn("get mem info", "id", i, "error", err) - continue - } - slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(freeMemory)) - gpus[i].FreeMemory = freeMemory - } - return nil -} - -func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { - ids := []string{} - for _, info := range gpuInfo { - if info.Library != "rocm" { - // TODO shouldn't happen if things are wired correctly... - slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library) - continue - } - ids = append(ids, info.ID) - } - // There are 3 potential env vars to use to select GPUs. - // ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows - // HIP_VISIBLE_DEVICES supports numeric IDs only - // GPU_DEVICE_ORDINAL supports numeric IDs only - return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",") -} diff --git a/discover/cpu_common.go b/discover/cpu_common.go deleted file mode 100644 index 2b9f72927..000000000 --- a/discover/cpu_common.go +++ /dev/null @@ -1,24 +0,0 @@ -package discover - -import ( - "os" - "path/filepath" - "runtime" - "strings" -) - -func IsNUMA() bool { - if runtime.GOOS != "linux" { - // numa support in llama.cpp is linux only - return false - } - ids := map[string]any{} - packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id") - for _, packageId := range packageIds { - id, err := os.ReadFile(packageId) - if err == nil { - ids[strings.TrimSpace(string(id))] = struct{}{} - } - } - return len(ids) > 1 -} diff --git a/discover/gpu_linux.go b/discover/cpu_linux.go similarity index 75% rename from discover/gpu_linux.go rename to discover/cpu_linux.go index 44c53b440..c3a0ef7fa 100644 --- a/discover/gpu_linux.go +++ b/discover/cpu_linux.go @@ -4,7 +4,9 @@ import ( "bufio" "fmt" "io" + "log/slog" "os" + "path/filepath" "reflect" "regexp" "sort" @@ -13,47 +15,6 @@ import ( "github.com/ollama/ollama/format" ) -var CudartGlobs = []string{ - "/usr/local/cuda/lib64/libcudart.so*", - "/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*", - "/usr/lib/x86_64-linux-gnu/libcudart.so*", - "/usr/lib/wsl/lib/libcudart.so*", - "/usr/lib/wsl/drivers/*/libcudart.so*", - "/opt/cuda/lib64/libcudart.so*", - "/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*", - "/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*", - "/usr/lib/aarch64-linux-gnu/libcudart.so*", - "/usr/local/cuda/lib*/libcudart.so*", - "/usr/lib*/libcudart.so*", - "/usr/local/lib*/libcudart.so*", -} - -var NvmlGlobs = []string{} - -var NvcudaGlobs = []string{ - "/usr/local/cuda*/targets/*/lib/libcuda.so*", - "/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*", - "/usr/lib/*-linux-gnu/libcuda.so*", - "/usr/lib/wsl/lib/libcuda.so*", - "/usr/lib/wsl/drivers/*/libcuda.so*", - "/opt/cuda/lib*/libcuda.so*", - "/usr/local/cuda/lib*/libcuda.so*", - "/usr/lib*/libcuda.so*", - "/usr/local/lib*/libcuda.so*", -} - -var OneapiGlobs = []string{ - "/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*", - "/usr/lib*/libze_intel_gpu.so*", -} - -var ( - CudartMgmtName = "libcudart.so*" - NvcudaMgmtName = "libcuda.so*" - NvmlMgmtName = "" // not currently wired on linux - OneapiMgmtName = "libze_intel_gpu.so*" -) - func GetCPUMem() (memInfo, error) { var mem memInfo var total, available, free, buffers, cached, freeSwap uint64 @@ -106,16 +67,17 @@ type linuxCpuInfo struct { CoreID string `cpuinfo:"core id"` } -func GetCPUDetails() ([]CPU, error) { +func GetCPUDetails() []CPU { file, err := os.Open(CpuInfoFilename) if err != nil { - return nil, err + slog.Warn("failed to get CPU details", "error", err) + return nil } defer file.Close() return linuxCPUDetails(file) } -func linuxCPUDetails(file io.Reader) ([]CPU, error) { +func linuxCPUDetails(file io.Reader) []CPU { reColumns := regexp.MustCompile("\t+: ") scanner := bufio.NewScanner(file) cpuInfos := []linuxCpuInfo{} @@ -194,5 +156,17 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) { for _, k := range keys { result = append(result, *socketByID[k]) } - return result, nil + return result +} + +func IsNUMA() bool { + ids := map[string]any{} + packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id") + for _, packageId := range packageIds { + id, err := os.ReadFile(packageId) + if err == nil { + ids[strings.TrimSpace(string(id))] = struct{}{} + } + } + return len(ids) > 1 } diff --git a/discover/gpu_linux_test.go b/discover/cpu_linux_test.go similarity index 99% rename from discover/gpu_linux_test.go rename to discover/cpu_linux_test.go index c4d64e389..3a5144780 100644 --- a/discover/gpu_linux_test.go +++ b/discover/cpu_linux_test.go @@ -2062,10 +2062,7 @@ power management: for k, v := range testCases { t.Run(k, func(t *testing.T) { buf := bytes.NewBufferString(v.input) - cpus, err := linuxCPUDetails(buf) - if err != nil { - t.Fatal(err) - } + cpus := linuxCPUDetails(buf) slog.Info("example", "scenario", k, "cpus", cpus) si := SystemInfo{ diff --git a/discover/gpu_windows.go b/discover/cpu_windows.go similarity index 82% rename from discover/gpu_windows.go rename to discover/cpu_windows.go index 2dc2f0746..5f516b5d1 100644 --- a/discover/gpu_windows.go +++ b/discover/cpu_windows.go @@ -26,29 +26,6 @@ var ( GetLogicalProcessorInformationEx = k32.NewProc("GetLogicalProcessorInformationEx") ) -var CudartGlobs = []string{ - "c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll", -} - -var NvmlGlobs = []string{ - "c:\\Windows\\System32\\nvml.dll", -} - -var NvcudaGlobs = []string{ - "c:\\windows\\system*\\nvcuda.dll", -} - -var OneapiGlobs = []string{ - "c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll", -} - -var ( - CudartMgmtName = "cudart64_*.dll" - NvcudaMgmtName = "nvcuda.dll" - NvmlMgmtName = "nvml.dll" - OneapiMgmtName = "ze_intel_gpu64.dll" -) - func GetCPUMem() (memInfo, error) { memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx} r1, _, err := globalMemoryStatusExProc.Call(uintptr(unsafe.Pointer(&memStatus))) @@ -122,27 +99,22 @@ func (pkg *winPackage) IsMember(target *GROUP_AFFINITY) bool { } func getLogicalProcessorInformationEx() ([]byte, error) { - buf := make([]byte, 1) + buf := make([]byte, 1024) bufSize := len(buf) - ret, _, err := GetLogicalProcessorInformationEx.Call( - uintptr(RelationAll), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&bufSize)), - ) - if ret != 0 { - return nil, fmt.Errorf("failed to determine size info ret:%d %w", ret, err) + var err error + for range 3 { + var ret uintptr + ret, _, err = GetLogicalProcessorInformationEx.Call( + uintptr(RelationAll), + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&bufSize)), + ) + if ret == 1 && bufSize <= len(buf) { + return buf, nil + } + buf = make([]byte, bufSize) } - - buf = make([]byte, bufSize) - ret, _, err = GetLogicalProcessorInformationEx.Call( - uintptr(RelationAll), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&bufSize)), - ) - if ret == 0 { - return nil, fmt.Errorf("failed to gather processor information ret:%d buflen:%d %w", ret, bufSize, err) - } - return buf, nil + return nil, fmt.Errorf("unable to determine CPU details: %w", err) } func processSystemLogicalProcessorInforationList(buf []byte) []*winPackage { @@ -217,10 +189,11 @@ func processSystemLogicalProcessorInforationList(buf []byte) []*winPackage { return packages } -func GetCPUDetails() ([]CPU, error) { +func GetCPUDetails() []CPU { buf, err := getLogicalProcessorInformationEx() if err != nil { - return nil, err + slog.Warn("failed to get CPU details", "error", err) + return nil } packages := processSystemLogicalProcessorInforationList(buf) cpus := make([]CPU, len(packages)) @@ -230,5 +203,10 @@ func GetCPUDetails() ([]CPU, error) { cpus[i].EfficiencyCoreCount = pkg.efficiencyCoreCount cpus[i].ThreadCount = pkg.threadCount } - return cpus, nil + return cpus +} + +func IsNUMA() bool { + // numa support in ggml is linux only + return false } diff --git a/discover/gpu_windows_test.go b/discover/cpu_windows_test.go similarity index 100% rename from discover/gpu_windows_test.go rename to discover/cpu_windows_test.go diff --git a/discover/cuda_common.go b/discover/cuda_common.go deleted file mode 100644 index 3c7cb6698..000000000 --- a/discover/cuda_common.go +++ /dev/null @@ -1,69 +0,0 @@ -//go:build linux || windows - -package discover - -import ( - "fmt" - "log/slog" - "os" - "regexp" - "runtime" - "strconv" - "strings" -) - -// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed. -// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. -var CudaTegra string = os.Getenv("JETSON_JETPACK") - -func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { - ids := []string{} - for _, info := range gpuInfo { - if info.Library != "cuda" { - // TODO shouldn't happen if things are wired correctly... - slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library) - continue - } - ids = append(ids, info.ID) - } - return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",") -} - -func cudaVariant(gpuInfo CudaGPUInfo) string { - if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" { - if CudaTegra != "" { - ver := strings.Split(CudaTegra, ".") - if len(ver) > 0 { - return "jetpack" + ver[0] - } - } else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil { - r := regexp.MustCompile(` R(\d+) `) - m := r.FindSubmatch(data) - if len(m) != 2 { - slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version") - } else { - if l4t, err := strconv.Atoi(string(m[1])); err == nil { - // Note: mapping from L4t -> JP is inconsistent (can't just subtract 30) - // https://developer.nvidia.com/embedded/jetpack-archive - switch l4t { - case 35: - return "jetpack5" - case 36: - return "jetpack6" - default: - slog.Info("unsupported L4T version", "nv_tegra_release", string(data)) - } - } - } - } - return "sbsa" - } - - // driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers - if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) { - // The detected driver is older than Feb 2023 - slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor)) - return "v11" - } - return "v12" -} diff --git a/discover/gpu.go b/discover/gpu.go index f6e3c9cb1..9175906d1 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -1,720 +1,148 @@ -//go:build linux || windows - package discover -/* -#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm -#cgo windows LDFLAGS: -lpthread - -#include "gpu_info.h" -*/ -import "C" - import ( + "context" "fmt" "log/slog" "os" "path/filepath" "runtime" - "strconv" "strings" - "sync" - "unsafe" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/ml" ) -type cudaHandles struct { - deviceCount int - cudart *C.cudart_handle_t - nvcuda *C.nvcuda_handle_t - nvml *C.nvml_handle_t +// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed. +// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. +var CudaTegra string = os.Getenv("JETSON_JETPACK") + +func GetCPUInfo() GpuInfo { + mem, err := GetCPUMem() + if err != nil { + slog.Warn("error looking up system memory", "error", err) + } + + return GpuInfo{ + memInfo: mem, + DeviceID: ml.DeviceID{ + Library: "cpu", + ID: "0", + }, + } } -type oneapiHandles struct { - oneapi *C.oneapi_handle_t - deviceCount int +func GetGPUInfo(ctx context.Context, runners []FilteredRunnerDiscovery) GpuInfoList { + devs := GPUDevices(ctx, runners) + return devInfoToInfoList(devs) } -const ( - cudaMinimumMemory = 457 * format.MebiByte - rocmMinimumMemory = 457 * format.MebiByte - // TODO OneAPI minimum memory -) - -var ( - gpuMutex sync.Mutex - bootstrapped bool - cpus []CPUInfo - cudaGPUs []CudaGPUInfo - nvcudaLibPath string - cudartLibPath string - oneapiLibPath string - nvmlLibPath string - rocmGPUs []RocmGPUInfo - oneapiGPUs []OneapiGPUInfo - - // If any discovered GPUs are incompatible, report why - unsupportedGPUs []UnsupportedGPUInfo - - // Keep track of errors during bootstrapping so that if GPUs are missing - // they expected to be present this may explain why - bootstrapErrors []error -) - -// With our current CUDA compile flags, older than 5.0 will not work properly -// (string values used to allow ldflags overrides at build time) -var ( - CudaComputeMajorMin = "5" - CudaComputeMinorMin = "0" -) - -var RocmComputeMajorMin = "9" - -// TODO find a better way to detect iGPU instead of minimum memory -const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU - -// Note: gpuMutex must already be held -func initCudaHandles() *cudaHandles { - // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing - - cHandles := &cudaHandles{} - // Short Circuit if we already know which library to use - // ignore bootstrap errors in this case since we already recorded them - if nvmlLibPath != "" { - cHandles.nvml, _, _ = loadNVMLMgmt([]string{nvmlLibPath}) - return cHandles - } - if nvcudaLibPath != "" { - cHandles.deviceCount, cHandles.nvcuda, _, _ = loadNVCUDAMgmt([]string{nvcudaLibPath}) - return cHandles - } - if cudartLibPath != "" { - cHandles.deviceCount, cHandles.cudart, _, _ = loadCUDARTMgmt([]string{cudartLibPath}) - return cHandles - } - - slog.Debug("searching for GPU discovery libraries for NVIDIA") - var cudartMgmtPatterns []string - - // Aligned with driver, we can't carry as payloads - nvcudaMgmtPatterns := NvcudaGlobs - cudartMgmtPatterns = append(cudartMgmtPatterns, filepath.Join(LibOllamaPath, "cuda_v*", CudartMgmtName)) - cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...) - - if len(NvmlGlobs) > 0 { - nvmlLibPaths := FindGPULibs(NvmlMgmtName, NvmlGlobs) - if len(nvmlLibPaths) > 0 { - nvml, libPath, err := loadNVMLMgmt(nvmlLibPaths) - if nvml != nil { - slog.Debug("nvidia-ml loaded", "library", libPath) - cHandles.nvml = nvml - nvmlLibPath = libPath - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - } - - nvcudaLibPaths := FindGPULibs(NvcudaMgmtName, nvcudaMgmtPatterns) - if len(nvcudaLibPaths) > 0 { - deviceCount, nvcuda, libPath, err := loadNVCUDAMgmt(nvcudaLibPaths) - if nvcuda != nil { - slog.Debug("detected GPUs", "count", deviceCount, "library", libPath) - cHandles.nvcuda = nvcuda - cHandles.deviceCount = deviceCount - nvcudaLibPath = libPath - return cHandles - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - - cudartLibPaths := FindGPULibs(CudartMgmtName, cudartMgmtPatterns) - if len(cudartLibPaths) > 0 { - deviceCount, cudart, libPath, err := loadCUDARTMgmt(cudartLibPaths) - if cudart != nil { - slog.Debug("detected GPUs", "library", libPath, "count", deviceCount) - cHandles.cudart = cudart - cHandles.deviceCount = deviceCount - cudartLibPath = libPath - return cHandles - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - - return cHandles -} - -// Note: gpuMutex must already be held -func initOneAPIHandles() *oneapiHandles { - oHandles := &oneapiHandles{} - - // Short Circuit if we already know which library to use - // ignore bootstrap errors in this case since we already recorded them - if oneapiLibPath != "" { - oHandles.deviceCount, oHandles.oneapi, _, _ = loadOneapiMgmt([]string{oneapiLibPath}) - return oHandles - } - - oneapiLibPaths := FindGPULibs(OneapiMgmtName, OneapiGlobs) - if len(oneapiLibPaths) > 0 { - var err error - oHandles.deviceCount, oHandles.oneapi, oneapiLibPath, err = loadOneapiMgmt(oneapiLibPaths) - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - - return oHandles -} - -func GetCPUInfo() GpuInfoList { - gpuMutex.Lock() - if !bootstrapped { - gpuMutex.Unlock() - GetGPUInfo() - } else { - gpuMutex.Unlock() - } - return GpuInfoList{cpus[0].GpuInfo} -} - -func GetGPUInfo() GpuInfoList { - // TODO - consider exploring lspci (and equivalent on windows) to check for - // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries - gpuMutex.Lock() - defer gpuMutex.Unlock() - needRefresh := true - var cHandles *cudaHandles - var oHandles *oneapiHandles - defer func() { - if cHandles != nil { - if cHandles.cudart != nil { - C.cudart_release(*cHandles.cudart) - } - if cHandles.nvcuda != nil { - C.nvcuda_release(*cHandles.nvcuda) - } - if cHandles.nvml != nil { - C.nvml_release(*cHandles.nvml) - } - } - if oHandles != nil { - if oHandles.oneapi != nil { - // TODO - is this needed? - C.oneapi_release(*oHandles.oneapi) - } - } - }() - - if !bootstrapped { - slog.Info("looking for compatible GPUs") - cudaComputeMajorMin, err := strconv.Atoi(CudaComputeMajorMin) - if err != nil { - slog.Error("invalid CudaComputeMajorMin setting", "value", CudaComputeMajorMin, "error", err) - } - cudaComputeMinorMin, err := strconv.Atoi(CudaComputeMinorMin) - if err != nil { - slog.Error("invalid CudaComputeMinorMin setting", "value", CudaComputeMinorMin, "error", err) - } - bootstrapErrors = []error{} - needRefresh = false - var memInfo C.mem_info_t - - mem, err := GetCPUMem() - if err != nil { - slog.Warn("error looking up system memory", "error", err) - } - - details, err := GetCPUDetails() - if err != nil { - slog.Warn("failed to lookup CPU details", "error", err) - } - cpus = []CPUInfo{ - { - GpuInfo: GpuInfo{ - memInfo: mem, - Library: "cpu", - ID: "0", - }, - CPUs: details, - }, - } - - // Load ALL libraries - cHandles = initCudaHandles() - - // NVIDIA - for i := range cHandles.deviceCount { - if cHandles.cudart != nil || cHandles.nvcuda != nil { - gpuInfo := CudaGPUInfo{ - GpuInfo: GpuInfo{ - Library: "cuda", - }, - index: i, - } - var driverMajor int - var driverMinor int - if cHandles.cudart != nil { - C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo) - driverMajor = int(cHandles.cudart.driver_major) - driverMinor = int(cHandles.cudart.driver_minor) - } else { - C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo) - driverMajor = int(cHandles.nvcuda.driver_major) - driverMinor = int(cHandles.nvcuda.driver_minor) - } - if memInfo.err != nil { - slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - continue - } - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) - gpuInfo.computeMajor = int(memInfo.major) - gpuInfo.computeMinor = int(memInfo.minor) - gpuInfo.MinimumMemory = cudaMinimumMemory - gpuInfo.DriverMajor = driverMajor - gpuInfo.DriverMinor = driverMinor - variant := cudaVariant(gpuInfo) - - // Start with our bundled libraries - if variant != "" { - variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant) - if _, err := os.Stat(variantPath); err == nil { - // Put the variant directory first in the search path to avoid runtime linking to the wrong library - gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...) - } - } - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.Variant = variant - - if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) { - unsupportedGPUs = append(unsupportedGPUs, - UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - }) - slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor)) - continue - } - - // query the management library as well so we can record any skew between the two - // which represents overhead on the GPU we must set aside on subsequent updates - if cHandles.nvml != nil { - uuid := C.CString(gpuInfo.ID) - defer C.free(unsafe.Pointer(uuid)) - C.nvml_get_free(*cHandles.nvml, uuid, &memInfo.free, &memInfo.total, &memInfo.used) - if memInfo.err != nil { - slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - } else { - if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory { - gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory - slog.Info("detected OS VRAM overhead", - "id", gpuInfo.ID, - "library", gpuInfo.Library, - "compute", gpuInfo.Compute, - "driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor), - "name", gpuInfo.Name, - "overhead", format.HumanBytes2(gpuInfo.OSOverhead), - ) - } - } - } - - // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... - cudaGPUs = append(cudaGPUs, gpuInfo) - } - } - - // Intel - if envconfig.IntelGPU() { - oHandles = initOneAPIHandles() - if oHandles != nil && oHandles.oneapi != nil { - for d := range oHandles.oneapi.num_drivers { - if oHandles.oneapi == nil { - // shouldn't happen - slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) - continue - } - devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) - for i := range devCount { - gpuInfo := OneapiGPUInfo{ - GpuInfo: GpuInfo{ - Library: "oneapi", - }, - driverIndex: int(d), - gpuIndex: int(i), - } - // TODO - split bootstrapping from updating free memory - C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) - // TODO - convert this to MinimumMemory based on testing... - var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. - memInfo.free = C.uint64_t(totalFreeMem) - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.DependencyPath = []string{LibOllamaPath} - oneapiGPUs = append(oneapiGPUs, gpuInfo) - } - } - } - } - - rocmGPUs, err = AMDGetGPUInfo() - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - bootstrapped = true - if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 { - slog.Info("no compatible GPUs were discovered") - } - - // TODO verify we have runners for the discovered GPUs, filter out any that aren't supported with good error messages - } - - // For detected GPUs, load library if not loaded - - // Refresh free memory usage - if needRefresh { - mem, err := GetCPUMem() - if err != nil { - slog.Warn("error looking up system memory", "error", err) - } else { - slog.Debug("updating system memory data", - slog.Group( - "before", - "total", format.HumanBytes2(cpus[0].TotalMemory), - "free", format.HumanBytes2(cpus[0].FreeMemory), - "free_swap", format.HumanBytes2(cpus[0].FreeSwap), - ), - slog.Group( - "now", - "total", format.HumanBytes2(mem.TotalMemory), - "free", format.HumanBytes2(mem.FreeMemory), - "free_swap", format.HumanBytes2(mem.FreeSwap), - ), - ) - cpus[0].FreeMemory = mem.FreeMemory - cpus[0].FreeSwap = mem.FreeSwap - } - - var memInfo C.mem_info_t - if cHandles == nil && len(cudaGPUs) > 0 { - cHandles = initCudaHandles() - } - for i, gpu := range cudaGPUs { - if cHandles.nvml != nil { - uuid := C.CString(gpu.ID) - defer C.free(unsafe.Pointer(uuid)) - C.nvml_get_free(*cHandles.nvml, uuid, &memInfo.free, &memInfo.total, &memInfo.used) - } else if cHandles.cudart != nil { - C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo) - } else if cHandles.nvcuda != nil { - C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free, &memInfo.total) - memInfo.used = memInfo.total - memInfo.free - } else { - // shouldn't happen - slog.Warn("no valid cuda library loaded to refresh vram usage") - break - } - if memInfo.err != nil { - slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - continue - } - if memInfo.free == 0 { - slog.Warn("error looking up nvidia GPU memory") - continue - } - if cHandles.nvml != nil && gpu.OSOverhead > 0 { - // When using the management library update based on recorded overhead - memInfo.free -= C.uint64_t(gpu.OSOverhead) - } - slog.Debug("updating cuda memory data", - "gpu", gpu.ID, - "name", gpu.Name, - "overhead", format.HumanBytes2(gpu.OSOverhead), - slog.Group( - "before", - "total", format.HumanBytes2(gpu.TotalMemory), - "free", format.HumanBytes2(gpu.FreeMemory), - ), - slog.Group( - "now", - "total", format.HumanBytes2(uint64(memInfo.total)), - "free", format.HumanBytes2(uint64(memInfo.free)), - "used", format.HumanBytes2(uint64(memInfo.used)), - ), - ) - cudaGPUs[i].FreeMemory = uint64(memInfo.free) - } - - if oHandles == nil && len(oneapiGPUs) > 0 { - oHandles = initOneAPIHandles() - } - for i, gpu := range oneapiGPUs { - if oHandles.oneapi == nil { - // shouldn't happen - slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount) - continue - } - C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo) - // TODO - convert this to MinimumMemory based on testing... - var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. - memInfo.free = C.uint64_t(totalFreeMem) - oneapiGPUs[i].FreeMemory = uint64(memInfo.free) - } - - err = RocmGPUInfoList(rocmGPUs).RefreshFreeMemory() - if err != nil { - slog.Debug("problem refreshing ROCm free memory", "error", err) - } - } - +func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList { resp := []GpuInfo{} - for _, gpu := range cudaGPUs { - resp = append(resp, gpu.GpuInfo) + // Our current packaging model places ggml-hip in the main directory + // but keeps rocm in an isolated directory. We have to add it to + // the [LD_LIBRARY_]PATH so ggml-hip will load properly + rocmDir := filepath.Join(LibOllamaPath, "rocm") + if _, err := os.Stat(rocmDir); err != nil { + rocmDir = "" } - for _, gpu := range rocmGPUs { - resp = append(resp, gpu.GpuInfo) - } - for _, gpu := range oneapiGPUs { - resp = append(resp, gpu.GpuInfo) + + for _, dev := range devs { + info := GpuInfo{ + DeviceID: dev.DeviceID, + filterID: dev.FilteredID, + Name: dev.Description, + memInfo: memInfo{ + TotalMemory: dev.TotalMemory, + FreeMemory: dev.FreeMemory, + }, + // TODO can we avoid variant + DependencyPath: dev.LibraryPath, + DriverMajor: dev.DriverMajor, + DriverMinor: dev.DriverMinor, + } + if dev.Library == "CUDA" || dev.Library == "ROCm" { + info.MinimumMemory = 457 * format.MebiByte + } + if dev.Library == "ROCm" { + info.Compute = fmt.Sprintf("gfx%x%02x", dev.ComputeMajor, dev.ComputeMinor) + if rocmDir != "" { + info.DependencyPath = append(info.DependencyPath, rocmDir) + } + } else { + info.Compute = fmt.Sprintf("%d.%d", dev.ComputeMajor, dev.ComputeMinor) + } + resp = append(resp, info) } if len(resp) == 0 { - resp = append(resp, cpus[0].GpuInfo) + mem, err := GetCPUMem() + if err != nil { + slog.Warn("error looking up system memory", "error", err) + } + + resp = append(resp, GpuInfo{ + memInfo: mem, + DeviceID: ml.DeviceID{ + Library: "cpu", + ID: "0", + }, + }) } return resp } -func FindGPULibs(baseLibName string, defaultPatterns []string) []string { - // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them - gpuLibPaths := []string{} - slog.Debug("Searching for GPU library", "name", baseLibName) - - // search our bundled libraries first - patterns := []string{filepath.Join(LibOllamaPath, baseLibName)} - - var ldPaths []string - switch runtime.GOOS { - case "windows": - ldPaths = strings.Split(os.Getenv("PATH"), string(os.PathListSeparator)) - case "linux": - ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), string(os.PathListSeparator)) - } - - // then search the system's LD_LIBRARY_PATH - for _, p := range ldPaths { - p, err := filepath.Abs(p) - if err != nil { - continue - } - patterns = append(patterns, filepath.Join(p, baseLibName)) - } - - // finally, search the default patterns provided by the caller - patterns = append(patterns, defaultPatterns...) - slog.Debug("gpu library search", "globs", patterns) - for _, pattern := range patterns { - // Nvidia PhysX known to return bogus results - if strings.Contains(pattern, "PhysX") { - slog.Debug("skipping PhysX cuda library path", "path", pattern) - continue - } - // Ignore glob discovery errors - matches, _ := filepath.Glob(pattern) - for _, match := range matches { - // Resolve any links so we don't try the same lib multiple times - // and weed out any dups across globs - libPath := match - tmp := match - var err error - for ; err == nil; tmp, err = os.Readlink(libPath) { - if !filepath.IsAbs(tmp) { - tmp = filepath.Join(filepath.Dir(libPath), tmp) - } - libPath = tmp - } - new := true - for _, cmp := range gpuLibPaths { - if cmp == libPath { - new = false - break - } - } - if new { - gpuLibPaths = append(gpuLibPaths, libPath) - } - } - } - slog.Debug("discovered GPU libraries", "paths", gpuLibPaths) - return gpuLibPaths -} - -// Bootstrap the runtime library -// Returns: num devices, handle, libPath, error -func loadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string, error) { - var resp C.cudart_init_resp_t - resp.ch.verbose = getVerboseState() - var err error - for _, libPath := range cudartLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.cudart_init(lib, &resp) - if resp.err != nil { - err = fmt.Errorf("Unable to load cudart library %s: %s", libPath, C.GoString(resp.err)) - slog.Debug(err.Error()) - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - return int(resp.num_devices), &resp.ch, libPath, err - } - } - return 0, nil, "", err -} - -// Bootstrap the driver library -// Returns: num devices, handle, libPath, error -func loadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string, error) { - var resp C.nvcuda_init_resp_t - resp.ch.verbose = getVerboseState() - var err error - for _, libPath := range nvcudaLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.nvcuda_init(lib, &resp) - if resp.err != nil { - // Decide what log level based on the type of error message to help users understand why - switch resp.cudaErr { - case C.CUDA_ERROR_INSUFFICIENT_DRIVER, C.CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: - err = fmt.Errorf("version mismatch between driver and cuda driver library - reboot or upgrade may be required: library %s", libPath) - slog.Warn(err.Error()) - case C.CUDA_ERROR_NO_DEVICE: - err = fmt.Errorf("no nvidia devices detected by library %s", libPath) - slog.Info(err.Error()) - case C.CUDA_ERROR_UNKNOWN: - err = fmt.Errorf("unknown error initializing cuda driver library %s: %s. see https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for more information", libPath, C.GoString(resp.err)) - slog.Warn(err.Error()) - default: - msg := C.GoString(resp.err) - if strings.Contains(msg, "wrong ELF class") { - slog.Debug("skipping 32bit library", "library", libPath) - } else { - err = fmt.Errorf("Unable to load cudart library %s: %s", libPath, C.GoString(resp.err)) - slog.Info(err.Error()) - } - } - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - return int(resp.num_devices), &resp.ch, libPath, err - } - } - return 0, nil, "", err -} - -// Bootstrap the management library -// Returns: handle, libPath, error -func loadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string, error) { - var resp C.nvml_init_resp_t - resp.ch.verbose = getVerboseState() - var err error - for _, libPath := range nvmlLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.nvml_init(lib, &resp) - if resp.err != nil { - err = fmt.Errorf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)) - slog.Info(err.Error()) - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - return &resp.ch, libPath, err - } - } - return nil, "", err -} - -// bootstrap the Intel GPU library -// Returns: num devices, handle, libPath, error -func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, error) { - var resp C.oneapi_init_resp_t - num_devices := 0 - resp.oh.verbose = getVerboseState() - var err error - for _, libPath := range oneapiLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.oneapi_init(lib, &resp) - if resp.err != nil { - err = fmt.Errorf("Unable to load oneAPI management library %s: %s", libPath, C.GoString(resp.err)) - slog.Debug(err.Error()) - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - for i := range resp.oh.num_drivers { - num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i))) - } - return num_devices, &resp.oh, libPath, err - } - } - return 0, nil, "", err -} - -func getVerboseState() C.uint16_t { - if envconfig.LogLevel() < slog.LevelInfo { - return C.uint16_t(1) - } - return C.uint16_t(0) -} - // Given the list of GPUs this instantiation is targeted for, // figure out the visible devices environment variable // // If different libraries are detected, the first one is what we use -func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) { +func (l GpuInfoList) GetVisibleDevicesEnv() []string { if len(l) == 0 { - return "", "" - } - switch l[0].Library { - case "cuda": - return cudaGetVisibleDevicesEnv(l) - case "rocm": - return rocmGetVisibleDevicesEnv(l) - case "oneapi": - return oneapiGetVisibleDevicesEnv(l) - default: - slog.Debug("no filter required for library " + l[0].Library) - return "", "" + return nil } + return []string{rocmGetVisibleDevicesEnv(l)} } -func GetSystemInfo() SystemInfo { - gpus := GetGPUInfo() - gpuMutex.Lock() - defer gpuMutex.Unlock() - discoveryErrors := []string{} - for _, err := range bootstrapErrors { - discoveryErrors = append(discoveryErrors, err.Error()) +func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "ROCm" { + continue + } + // If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number + if info.filterID != "" { + ids = append(ids, info.filterID) + } else { + ids = append(ids, info.ID) + } } + if len(ids) == 0 { + return "" + } + envVar := "ROCR_VISIBLE_DEVICES=" + if runtime.GOOS != "linux" { + envVar = "HIP_VISIBLE_DEVICES=" + } + // There are 3 potential env vars to use to select GPUs. + // ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows + // HIP_VISIBLE_DEVICES supports numeric IDs only + // GPU_DEVICE_ORDINAL supports numeric IDs only + return envVar + strings.Join(ids, ",") +} + +// GetSystemInfo returns the last cached state of the GPUs on the system +func GetSystemInfo() SystemInfo { + deviceMu.Lock() + defer deviceMu.Unlock() + gpus := devInfoToInfoList(devices) if len(gpus) == 1 && gpus[0].Library == "cpu" { gpus = []GpuInfo{} } return SystemInfo{ - System: cpus[0], - GPUs: gpus, - UnsupportedGPUs: unsupportedGPUs, - DiscoveryErrors: discoveryErrors, + System: CPUInfo{ + CPUs: GetCPUDetails(), + GpuInfo: GetCPUInfo(), + }, + GPUs: gpus, } } diff --git a/discover/gpu_darwin.go b/discover/gpu_darwin.go index dd5bf6e27..6f55b4c57 100644 --- a/discover/gpu_darwin.go +++ b/discover/gpu_darwin.go @@ -1,5 +1,3 @@ -//go:build darwin - package discover /* @@ -11,7 +9,6 @@ import "C" import ( "log/slog" - "runtime" "syscall" "github.com/ollama/ollama/format" @@ -21,39 +18,6 @@ const ( metalMinimumMemory = 512 * format.MebiByte ) -func GetGPUInfo() GpuInfoList { - mem, _ := GetCPUMem() - if runtime.GOARCH == "amd64" { - return []GpuInfo{ - { - Library: "cpu", - memInfo: mem, - }, - } - } - info := GpuInfo{ - Library: "metal", - ID: "0", - } - info.TotalMemory = uint64(C.getRecommendedMaxVRAM()) - - // TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work) - info.FreeMemory = info.TotalMemory - - info.MinimumMemory = metalMinimumMemory - return []GpuInfo{info} -} - -func GetCPUInfo() GpuInfoList { - mem, _ := GetCPUMem() - return []GpuInfo{ - { - Library: "cpu", - memInfo: mem, - }, - } -} - func GetCPUMem() (memInfo, error) { return memInfo{ TotalMemory: uint64(C.getPhysicalMemory()), @@ -62,13 +26,7 @@ func GetCPUMem() (memInfo, error) { }, nil } -func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) { - // No-op on darwin - return "", "" -} - -func GetSystemInfo() SystemInfo { - mem, _ := GetCPUMem() +func GetCPUDetails() []CPU { query := "hw.perflevel0.physicalcpu" perfCores, err := syscall.SysctlUint32(query) if err != nil { @@ -81,19 +39,16 @@ func GetSystemInfo() SystemInfo { query = "hw.logicalcpu" logicalCores, _ := syscall.SysctlUint32(query) - return SystemInfo{ - System: CPUInfo{ - GpuInfo: GpuInfo{ - memInfo: mem, - }, - CPUs: []CPU{ - { - CoreCount: int(perfCores + efficiencyCores), - EfficiencyCoreCount: int(efficiencyCores), - ThreadCount: int(logicalCores), - }, - }, + return []CPU{ + { + CoreCount: int(perfCores + efficiencyCores), + EfficiencyCoreCount: int(efficiencyCores), + ThreadCount: int(logicalCores), }, - GPUs: GetGPUInfo(), } } + +func IsNUMA() bool { + // numa support in ggml is linux only + return false +} diff --git a/discover/gpu_info.h b/discover/gpu_info.h deleted file mode 100644 index ee7ff4c33..000000000 --- a/discover/gpu_info.h +++ /dev/null @@ -1,72 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_H__ -#define __GPU_INFO_H__ -#include -#include -#include - -#ifndef _WIN32 -#include -#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags) -#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym) -#define LOAD_ERR() strdup(dlerror()) -#define UNLOAD_LIBRARY(handle) dlclose(handle) -#else -#include -#define LOAD_LIBRARY(lib, flags) LoadLibrary(lib) -#define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym) -#define UNLOAD_LIBRARY(handle) FreeLibrary(handle) -#define LOAD_ERR() ({\ - LPSTR messageBuffer = NULL; \ - size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, \ - NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); \ - char *resp = strdup(messageBuffer); \ - LocalFree(messageBuffer); \ - resp; \ -}) - -#endif - -#ifndef LOG -#define LOG(verbose, ...) \ - do { \ - if (verbose) { \ - fprintf(stderr, __VA_ARGS__); \ - } \ - } while (0) -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -#define GPU_ID_LEN 64 -#define GPU_NAME_LEN 96 - -typedef struct mem_info { - char *err; // If non-nill, caller responsible for freeing - char gpu_id[GPU_ID_LEN]; - char gpu_name[GPU_NAME_LEN]; - uint64_t total; - uint64_t free; - uint64_t used; - - // Compute Capability - int major; - int minor; - int patch; -} mem_info_t; - -void cpu_check_ram(mem_info_t *resp); - -#ifdef __cplusplus -} -#endif - -#include "gpu_info_cudart.h" -#include "gpu_info_nvcuda.h" -#include "gpu_info_nvml.h" -#include "gpu_info_oneapi.h" - -#endif // __GPU_INFO_H__ -#endif // __APPLE__ diff --git a/discover/gpu_info_cudart.c b/discover/gpu_info_cudart.c deleted file mode 100644 index 76c17b9d8..000000000 --- a/discover/gpu_info_cudart.c +++ /dev/null @@ -1,181 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? - -#include -#include -#include "gpu_info_cudart.h" - -void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { - cudartReturn_t ret; - resp->err = NULL; - resp->num_devices = 0; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - {"cudaSetDevice", (void *)&resp->ch.cudaSetDevice}, - {"cudaDeviceSynchronize", (void *)&resp->ch.cudaDeviceSynchronize}, - {"cudaDeviceReset", (void *)&resp->ch.cudaDeviceReset}, - {"cudaMemGetInfo", (void *)&resp->ch.cudaMemGetInfo}, - {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount}, - {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute}, - {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion}, - {"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(cudart_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", cudart_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - cudart_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - for (i = 0; l[i].s != NULL; i++) { - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!*(l[i].p)) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - ret = (*resp->ch.cudaSetDevice)(0); - if (ret != CUDART_SUCCESS) { - LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) { - resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama"); - return; - } - snprintf(buf, buflen, "cudart init failure: %d", ret); - resp->err = strdup(buf); - return; - } - - int version = 0; - - // Report driver version if we're in verbose mode, ignore errors - ret = (*resp->ch.cudaDriverGetVersion)(&version); - if (ret != CUDART_SUCCESS) { - LOG(resp->ch.verbose, "cudaDriverGetVersion failed: %d\n", ret); - } else { - resp->ch.driver_major = version / 1000; - resp->ch.driver_minor = (version - (resp->ch.driver_major * 1000)) / 10; - LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", resp->ch.driver_major, resp->ch.driver_minor); - } - - ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices); - if (ret != CUDART_SUCCESS) { - LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - return; - } -} - - -void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) { - resp->err = NULL; - cudartMemory_t memInfo = {0,0,0}; - cudartReturn_t ret; - const int buflen = 256; - char buf[buflen + 1]; - - if (h.handle == NULL) { - resp->err = strdup("cudart handle isn't initialized"); - return; - } - - ret = (*h.cudaSetDevice)(i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device failed to initialize"); - resp->err = strdup(buf); - return; - } - - cudaDeviceProp_t props; - ret = (*h.cudaGetDeviceProperties)(&props, i); - if (ret != CUDART_SUCCESS) { - LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret); - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - resp->major = 0; - resp->minor = 0; - } else { - int allNull = 1; - for (int j = 0; j < 16; j++) { - if (props.uuid.bytes[j] != 0) { - allNull = 0; - break; - } - } - if (allNull != 0) { - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - } else { - // GPU-d110a105-ac29-1d54-7b49-9c90440f215b - snprintf(&resp->gpu_id[0], GPU_ID_LEN, - "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", - props.uuid.bytes[0], - props.uuid.bytes[1], - props.uuid.bytes[2], - props.uuid.bytes[3], - props.uuid.bytes[4], - props.uuid.bytes[5], - props.uuid.bytes[6], - props.uuid.bytes[7], - props.uuid.bytes[8], - props.uuid.bytes[9], - props.uuid.bytes[10], - props.uuid.bytes[11], - props.uuid.bytes[12], - props.uuid.bytes[13], - props.uuid.bytes[14], - props.uuid.bytes[15] - ); - } - resp->major = props.major; - resp->minor = props.minor; - - // TODO add other useful properties from props - } - ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret); - resp->err = strdup(buf); - return; - } - - resp->total = memInfo.total; - resp->free = memInfo.free; - resp->used = memInfo.used; - - LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "\n", resp->gpu_id, resp->total); - LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "\n", resp->gpu_id, resp->free); - LOG(h.verbose, "[%s] CUDA usedMem %" PRId64 "\n", resp->gpu_id, resp->used); - LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); -} - -void cudart_release(cudart_handle_t h) { - LOG(h.verbose, "releasing cudart library\n"); - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -#endif // __APPLE__ diff --git a/discover/gpu_info_cudart.h b/discover/gpu_info_cudart.h deleted file mode 100644 index 893f3f7bd..000000000 --- a/discover/gpu_info_cudart.h +++ /dev/null @@ -1,145 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_CUDART_H__ -#define __GPU_INFO_CUDART_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum cudartReturn_enum { - CUDART_SUCCESS = 0, - CUDART_ERROR_INVALID_VALUE = 1, - CUDART_ERROR_MEMORY_ALLOCATION = 2, - CUDART_ERROR_INSUFFICIENT_DRIVER = 35, - // Other values omitted for now... -} cudartReturn_t; - -typedef enum cudartDeviceAttr_enum { - cudartDevAttrComputeCapabilityMajor = 75, - cudartDevAttrComputeCapabilityMinor = 76, - - // TODO - not yet wired up but may be useful for Jetson or other - // integrated GPU scenarios with shared memory - cudaDevAttrIntegrated = 18 - -} cudartDeviceAttr_t; - -typedef void *cudartDevice_t; // Opaque is sufficient -typedef struct cudartMemory_st { - size_t total; - size_t free; - size_t used; -} cudartMemory_t; - -typedef struct cudaUUID { - unsigned char bytes[16]; -} cudaUUID_t; -typedef struct cudaDeviceProp { - char name[256]; /**< ASCII string identifying device */ - cudaUUID_t uuid; /**< 16-byte unique identifier */ - char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */ - unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */ - size_t totalGlobalMem; /**< Global memory available on device in bytes */ - size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */ - int regsPerBlock; /**< 32-bit registers available per block */ - int warpSize; /**< Warp size in threads */ - size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */ - int maxThreadsPerBlock; /**< Maximum number of threads per block */ - int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */ - int maxGridSize[3]; /**< Maximum size of each dimension of a grid */ - int clockRate; /**< Clock frequency in kilohertz */ - size_t totalConstMem; /**< Constant memory available on device in bytes */ - int major; /**< Major compute capability */ - int minor; /**< Minor compute capability */ - size_t textureAlignment; /**< Alignment requirement for textures */ - size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */ - int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */ - int multiProcessorCount; /**< Number of multiprocessors on device */ - int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */ - int integrated; /**< Device is integrated as opposed to discrete */ - int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */ - int computeMode; /**< Compute mode (See ::cudaComputeMode) */ - int maxTexture1D; /**< Maximum 1D texture size */ - int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */ - int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */ - int maxTexture2D[2]; /**< Maximum 2D texture dimensions */ - int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */ - int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */ - int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */ - int maxTexture3D[3]; /**< Maximum 3D texture dimensions */ - int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */ - int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */ - int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */ - int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */ - int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */ - int maxSurface1D; /**< Maximum 1D surface size */ - int maxSurface2D[2]; /**< Maximum 2D surface dimensions */ - int maxSurface3D[3]; /**< Maximum 3D surface dimensions */ - int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */ - int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */ - int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */ - int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */ - size_t surfaceAlignment; /**< Alignment requirements for surfaces */ - int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */ - int ECCEnabled; /**< Device has ECC support enabled */ - int pciBusID; /**< PCI bus ID of the device */ - int pciDeviceID; /**< PCI device ID of the device */ - int pciDomainID; /**< PCI domain ID of the device */ - int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */ - int asyncEngineCount; /**< Number of asynchronous engines */ - int unifiedAddressing; /**< Device shares a unified address space with the host */ - int memoryClockRate; /**< Peak memory clock frequency in kilohertz */ - int memoryBusWidth; /**< Global memory bus width in bits */ - int l2CacheSize; /**< Size of L2 cache in bytes */ - int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */ - int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */ - int streamPrioritiesSupported; /**< Device supports stream priorities */ - int globalL1CacheSupported; /**< Device supports caching globals in L1 */ - int localL1CacheSupported; /**< Device supports caching locals in L1 */ - size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */ - int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */ - int managedMemory; /**< Device supports allocating managed memory on this system */ - int isMultiGpuBoard; /**< Device is on a multi-GPU board */ - int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */ - int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */ - int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */ - int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */ - int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */ - int computePreemptionSupported; /**< Device supports Compute Preemption */ - int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */ - int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */ - int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */ - size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */ - int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */ - int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */ - int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */ - int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */ - size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */ - } cudaDeviceProp_t; - -typedef struct cudart_handle { - void *handle; - uint16_t verbose; - int driver_major; - int driver_minor; - cudartReturn_t (*cudaSetDevice)(int device); - cudartReturn_t (*cudaDeviceSynchronize)(void); - cudartReturn_t (*cudaDeviceReset)(void); - cudartReturn_t (*cudaMemGetInfo)(size_t *, size_t *); - cudartReturn_t (*cudaGetDeviceCount)(int *); - cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device); - cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion); - cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device); -} cudart_handle_t; - -typedef struct cudart_init_resp { - char *err; // If err is non-null handle is invalid - cudart_handle_t ch; - int num_devices; -} cudart_init_resp_t; - -void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp); -void cudart_bootstrap(cudart_handle_t ch, int device_id, mem_info_t *resp); -// TODO - if we keep this library longer term, add cudart_get_free -void cudart_release(cudart_handle_t ch); - -#endif // __GPU_INFO_CUDART_H__ -#endif // __APPLE__ diff --git a/discover/gpu_info_nvcuda.c b/discover/gpu_info_nvcuda.c deleted file mode 100644 index d2d0b683b..000000000 --- a/discover/gpu_info_nvcuda.c +++ /dev/null @@ -1,251 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? - -#include -#include -#include "gpu_info_nvcuda.h" - -void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) { - LOG(resp->ch.verbose, "initializing %s\n", nvcuda_lib_path); - CUresult ret; - resp->err = NULL; - resp->num_devices = 0; - resp->cudaErr = CUDA_SUCCESS; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - - {"cuInit", (void *)&resp->ch.cuInit}, - {"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion}, - {"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount}, - {"cuDeviceGet", (void *)&resp->ch.cuDeviceGet}, - {"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute}, - {"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid}, - {"cuDeviceGetName", (void *)&resp->ch.cuDeviceGetName}, - {"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3}, - {"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2}, - {"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - nvcuda_lib_path, msg); - free(msg); - resp->err = strdup(buf); - resp->cudaErr = -1; - return; - } - - for (i = 0; l[i].s != NULL; i++) { - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!*(l[i].p)) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - resp->cudaErr = -1; - return; - } - LOG(resp->ch.verbose, "dlsym: %s - %p\n", l[i].s, *l[i].p); - } - - LOG(resp->ch.verbose, "calling cuInit\n"); - ret = (*resp->ch.cuInit)(0); - if (ret != CUDA_SUCCESS) { - LOG(resp->ch.verbose, "cuInit err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "cuda driver library init failure: %d", ret); - resp->err = strdup(buf); - resp->cudaErr = ret; - return; - } - - int version = 0; - resp->ch.driver_major = 0; - resp->ch.driver_minor = 0; - - // Report driver version if we're in verbose mode, ignore errors - LOG(resp->ch.verbose, "calling cuDriverGetVersion\n"); - ret = (*resp->ch.cuDriverGetVersion)(&version); - if (ret != CUDA_SUCCESS) { - LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret); - } else { - LOG(resp->ch.verbose, "raw version 0x%x\n", version); - resp->ch.driver_major = version / 1000; - resp->ch.driver_minor = (version - (resp->ch.driver_major * 1000)) / 10; - LOG(resp->ch.verbose, "CUDA driver version: %d.%d\n", resp->ch.driver_major, resp->ch.driver_minor); - } - - LOG(resp->ch.verbose, "calling cuDeviceGetCount\n"); - ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices); - if (ret != CUDA_SUCCESS) { - LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - resp->cudaErr = ret; - return; - } - LOG(resp->ch.verbose, "device count %d\n", resp->num_devices); -} - -const int buflen = 256; -void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) { - resp->err = NULL; - nvcudaMemory_t memInfo = {0,0}; - CUresult ret; - CUdevice device = -1; - CUcontext ctx = NULL; - char buf[buflen + 1]; - CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; - - if (h.handle == NULL) { - resp->err = strdup("cuda driver library handle isn't initialized"); - return; - } - - ret = (*h.cuDeviceGet)(&device, i); - if (ret != CUDA_SUCCESS) { - snprintf(buf, buflen, "cuda driver library device failed to initialize"); - resp->err = strdup(buf); - return; - } - - int major = 0; - int minor = 0; - ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret); - } else { - ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret); - } else { - resp->minor = minor; - resp->major = major; - } - } - - ret = (*h.cuDeviceGetUuid)(&uuid, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret); - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - } else { - // GPU-d110a105-ac29-1d54-7b49-9c90440f215b - snprintf(&resp->gpu_id[0], GPU_ID_LEN, - "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", - uuid.bytes[0], - uuid.bytes[1], - uuid.bytes[2], - uuid.bytes[3], - uuid.bytes[4], - uuid.bytes[5], - uuid.bytes[6], - uuid.bytes[7], - uuid.bytes[8], - uuid.bytes[9], - uuid.bytes[10], - uuid.bytes[11], - uuid.bytes[12], - uuid.bytes[13], - uuid.bytes[14], - uuid.bytes[15] - ); - } - - ret = (*h.cuDeviceGetName)(&resp->gpu_name[0], GPU_NAME_LEN, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device name lookup failure: %d\n", i, ret); - resp->gpu_name[0] = '\0'; - } - - // To get memory we have to set (and release) a context - ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device); - if (ret != CUDA_SUCCESS) { - snprintf(buf, buflen, "cuda driver library failed to get device context %d", ret); - resp->err = strdup(buf); - return; - } - - ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total); - if (ret != CUDA_SUCCESS) { - snprintf(buf, buflen, "cuda driver library device memory info lookup failure %d", ret); - resp->err = strdup(buf); - // Best effort on failure... - (*h.cuCtxDestroy)(ctx); - return; - } - - resp->total = memInfo.total; - resp->free = memInfo.free; - - LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "mb\n", resp->gpu_id, resp->total / 1024 / 1024); - LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "mb\n", resp->gpu_id, resp->free / 1024 / 1024); - LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); - - - - ret = (*h.cuCtxDestroy)(ctx); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library failed to release device context %d", ret); - } -} - -void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total) { - CUresult ret; - CUcontext ctx = NULL; - CUdevice device = -1; - *free = 0; - *total = 0; - - ret = (*h.cuDeviceGet)(&device, i); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library device failed to initialize"); - return; - } - - - // To get memory we have to set (and release) a context - ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library failed to get device context %d", ret); - return; - } - - ret = (*h.cuMemGetInfo_v2)(free, total); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library device memory info lookup failure %d", ret); - // Best effort on failure... - (*h.cuCtxDestroy)(ctx); - return; - } - - ret = (*h.cuCtxDestroy)(ctx); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library failed to release device context %d", ret); - } -} - -void nvcuda_release(nvcuda_handle_t h) { - LOG(h.verbose, "releasing cuda driver library\n"); - UNLOAD_LIBRARY(h.handle); - // TODO and other context release logic? - h.handle = NULL; -} - -#endif // __APPLE__ diff --git a/discover/gpu_info_nvcuda.h b/discover/gpu_info_nvcuda.h deleted file mode 100644 index ef2fe8a30..000000000 --- a/discover/gpu_info_nvcuda.h +++ /dev/null @@ -1,79 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_NVCUDA_H__ -#define __GPU_INFO_NVCUDA_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum cudaError_enum { - CUDA_SUCCESS = 0, - CUDA_ERROR_INVALID_VALUE = 1, - CUDA_ERROR_OUT_OF_MEMORY = 2, - CUDA_ERROR_NOT_INITIALIZED = 3, - CUDA_ERROR_INSUFFICIENT_DRIVER = 35, - CUDA_ERROR_NO_DEVICE = 100, - CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803, - CUDA_ERROR_UNKNOWN = 999, - // Other values omitted for now... -} CUresult; - -typedef enum CUdevice_attribute_enum { - CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75, - CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76, - - // TODO - not yet wired up but may be useful for Jetson or other - // integrated GPU scenarios with shared memory - CU_DEVICE_ATTRIBUTE_INTEGRATED = 18 - -} CUdevice_attribute; - -typedef void *nvcudaDevice_t; // Opaque is sufficient -typedef struct nvcudaMemory_st { - uint64_t total; - uint64_t free; -} nvcudaMemory_t; - -typedef struct nvcudaDriverVersion { - int major; - int minor; -} nvcudaDriverVersion_t; - -typedef struct CUuuid_st { - unsigned char bytes[16]; -} CUuuid; - -typedef int CUdevice; -typedef void* CUcontext; - -typedef struct nvcuda_handle { - void *handle; - uint16_t verbose; - int driver_major; - int driver_minor; - CUresult (*cuInit)(unsigned int Flags); - CUresult (*cuDriverGetVersion)(int *driverVersion); - CUresult (*cuDeviceGetCount)(int *); - CUresult (*cuDeviceGet)(CUdevice* device, int ordinal); - CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev); - CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2 - CUresult (*cuDeviceGetName)(char *name, int len, CUdevice dev); - - // Context specific aspects - CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev); - CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total); - CUresult (*cuCtxDestroy)(CUcontext ctx); -} nvcuda_handle_t; - -typedef struct nvcuda_init_resp { - char *err; // If err is non-null handle is invalid - nvcuda_handle_t ch; - int num_devices; - CUresult cudaErr; -} nvcuda_init_resp_t; - -void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp); -void nvcuda_bootstrap(nvcuda_handle_t ch, int device_id, mem_info_t *resp); -void nvcuda_get_free(nvcuda_handle_t ch, int device_id, uint64_t *free, uint64_t *total); -void nvcuda_release(nvcuda_handle_t ch); - -#endif // __GPU_INFO_NVCUDA_H__ -#endif // __APPLE__ diff --git a/discover/gpu_info_nvml.c b/discover/gpu_info_nvml.c deleted file mode 100644 index 342a3aa4b..000000000 --- a/discover/gpu_info_nvml.c +++ /dev/null @@ -1,104 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? - -#include - -#include "gpu_info_nvml.h" - -void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) { - nvmlReturn_t ret; - resp->err = NULL; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2}, - {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown}, - {"nvmlDeviceGetHandleByUUID", (void *)&resp->ch.nvmlDeviceGetHandleByUUID}, - {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - nvml_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - // TODO once we've squashed the remaining corner cases remove this log - // LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path); - - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - // LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s); - - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!*(l[i].p)) { - resp->ch.handle = NULL; - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.handle); - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - ret = (*resp->ch.nvmlInit_v2)(); - if (ret != NVML_SUCCESS) { - LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "nvml vram init failure: %d", ret); - resp->err = strdup(buf); - return; - } -} - - -void nvml_get_free(nvml_handle_t h, char *uuid, uint64_t *free, uint64_t *total, uint64_t *used) { - nvmlDevice_t device; - nvmlMemory_t memInfo = {0}; - nvmlReturn_t ret; - ret = (*h.nvmlDeviceGetHandleByUUID)((const char *)(uuid), &device); - if (ret != NVML_SUCCESS) { - LOG(1, "unable to get device handle %s: %d", uuid, ret); - *free = 0; - return; - } - - ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo); - if (ret != NVML_SUCCESS) { - LOG(1, "device memory info lookup failure %s: %d", uuid, ret); - *free = 0; - return; - } - *free = memInfo.free; - *total = memInfo.total; - *used = memInfo.used; -} - - -void nvml_release(nvml_handle_t h) { - LOG(h.verbose, "releasing nvml library\n"); - nvmlReturn_t ret; - ret = (*h.nvmlShutdown)(); - if (ret != NVML_SUCCESS) { - LOG(1, "error during nvmlShutdown %d", ret); - } - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -#endif // __APPLE__ \ No newline at end of file diff --git a/discover/gpu_info_nvml.h b/discover/gpu_info_nvml.h deleted file mode 100644 index 908802337..000000000 --- a/discover/gpu_info_nvml.h +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_NVML_H__ -#define __GPU_INFO_NVML_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum nvmlReturn_enum { - NVML_SUCCESS = 0, - // Other values omitted for now... -} nvmlReturn_t; -typedef void *nvmlDevice_t; // Opaque is sufficient -typedef struct nvmlMemory_st { - unsigned long long total; - unsigned long long free; - unsigned long long used; -} nvmlMemory_t; - -typedef enum nvmlBrandType_enum -{ - NVML_BRAND_UNKNOWN = 0, -} nvmlBrandType_t; - -typedef struct nvml_handle { - void *handle; - uint16_t verbose; - nvmlReturn_t (*nvmlInit_v2)(void); - nvmlReturn_t (*nvmlShutdown)(void); - nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); - nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); -} nvml_handle_t; - -typedef struct nvml_init_resp { - char *err; // If err is non-null handle is invalid - nvml_handle_t ch; -} nvml_init_resp_t; - -typedef struct nvml_compute_capability { - char *err; - int major; - int minor; -} nvml_compute_capability_t; - -void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp); -void nvml_get_free(nvml_handle_t ch, char *uuid, uint64_t *free, uint64_t *total, uint64_t *used); -void nvml_release(nvml_handle_t ch); - -#endif // __GPU_INFO_NVML_H__ -#endif // __APPLE__ \ No newline at end of file diff --git a/discover/gpu_info_oneapi.c b/discover/gpu_info_oneapi.c deleted file mode 100644 index 3ff708ea2..000000000 --- a/discover/gpu_info_oneapi.c +++ /dev/null @@ -1,259 +0,0 @@ -#ifndef __APPLE__ - -#include "gpu_info_oneapi.h" - -#include - -void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) { - ze_result_t ret; - resp->err = NULL; - resp->oh.devices = NULL; - resp->oh.num_devices = NULL; - resp->oh.drivers = NULL; - resp->oh.num_drivers = 0; - const int buflen = 256; - char buf[buflen + 1]; - int i, d; - struct lookup { - char *s; - void **p; - } l[] = { - {"zesInit", (void *)&resp->oh.zesInit}, - {"zesDriverGet", (void *)&resp->oh.zesDriverGet}, - {"zesDeviceGet", (void *)&resp->oh.zesDeviceGet}, - {"zesDeviceGetProperties", (void *)&resp->oh.zesDeviceGetProperties}, - {"zesDeviceEnumMemoryModules", - (void *)&resp->oh.zesDeviceEnumMemoryModules}, - {"zesMemoryGetProperties", (void *)&resp->oh.zesMemoryGetProperties}, - {"zesMemoryGetState", (void *)&resp->oh.zesMemoryGetState}, - {NULL, NULL}, - }; - - resp->oh.handle = LOAD_LIBRARY(oneapi_lib_path, RTLD_LAZY); - if (!resp->oh.handle) { - char *msg = LOAD_ERR(); - snprintf(buf, buflen, - "Unable to load %s library to query for Intel GPUs: %s\n", - oneapi_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->oh.verbose, - "wiring Level-Zero management library functions in %s\n", - oneapi_lib_path); - - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s); - - *l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s); - if (!*(l[i].p)) { - resp->oh.handle = NULL; - char *msg = LOAD_ERR(); - LOG(resp->oh.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->oh.handle); - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - LOG(resp->oh.verbose, "calling zesInit\n"); - - ret = (*resp->oh.zesInit)(0); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesInit err: %x\n", ret); - snprintf(buf, buflen, "oneapi vram init failure: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - - LOG(resp->oh.verbose, "calling zesDriverGet\n"); - ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get driver count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers); - resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t)); - resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t)); - memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t)); - resp->oh.devices = - malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t *)); - ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get driver count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - - for (d = 0; d < resp->oh.num_drivers; d++) { - LOG(resp->oh.verbose, "calling zesDeviceGet count %d: %p\n", d, resp->oh.drivers[d]); - ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], - &resp->oh.num_devices[d], NULL); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get device count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - resp->oh.devices[d] = - malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t)); - ret = (*resp->oh.zesDeviceGet)( - resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get device count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - } - - return; -} - -void oneapi_check_vram(oneapi_handle_t h, int driver, int device, - mem_info_t *resp) { - ze_result_t ret; - resp->err = NULL; - uint64_t totalMem = 0; - uint64_t usedMem = 0; - const int buflen = 256; - char buf[buflen + 1]; - int i, d, m; - - if (h.handle == NULL) { - resp->err = strdup("Level-Zero handle not initialized"); - return; - } - - if (driver > h.num_drivers || device > h.num_devices[driver]) { - resp->err = strdup("driver of device index out of bounds"); - return; - } - - resp->total = 0; - resp->free = 0; - - zes_device_ext_properties_t ext_props; - ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES; - ext_props.pNext = NULL; - - zes_device_properties_t props; - props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES; - props.pNext = &ext_props; - - ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props); - if (ret != ZE_RESULT_SUCCESS) { - snprintf(buf, buflen, "unable to get device properties: %d", ret); - resp->err = strdup(buf); - return; - } - - snprintf(&resp->gpu_name[0], GPU_NAME_LEN, "%s", props.modelName); - - // TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax - // (this is probably wrong...) - // TODO - the driver isn't included - what if there are multiple drivers? - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device); - - if (h.verbose) { - // When in verbose mode, report more information about - // the card we discover. - LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device, - props.modelName); - LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device, - props.brandName); - LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device, - props.vendorName); - LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device, - props.serialNumber); - LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device, - props.boardNumber); - } - - // TODO - // Compute Capability equivalent in resp->major, resp->minor, resp->patch - - uint32_t memCount = 0; - ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, - NULL); - if (ret != ZE_RESULT_SUCCESS) { - snprintf(buf, buflen, "unable to enumerate Level-Zero memory modules: %x", - ret); - resp->err = strdup(buf); - return; - } - - LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount); - - zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t)); - (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems); - - for (m = 0; m < memCount; m++) { - zes_mem_state_t state; - state.stype = ZES_STRUCTURE_TYPE_MEM_STATE; - state.pNext = NULL; - ret = (*h.zesMemoryGetState)(mems[m], &state); - if (ret != ZE_RESULT_SUCCESS) { - snprintf(buf, buflen, "unable to get memory state: %x", ret); - resp->err = strdup(buf); - free(mems); - return; - } - - resp->total += state.size; - resp->free += state.free; - } - - free(mems); -} - -void oneapi_release(oneapi_handle_t h) { - int d; - LOG(h.verbose, "releasing oneapi library\n"); - for (d = 0; d < h.num_drivers; d++) { - if (h.devices != NULL && h.devices[d] != NULL) { - free(h.devices[d]); - } - } - if (h.devices != NULL) { - free(h.devices); - h.devices = NULL; - } - if (h.num_devices != NULL) { - free(h.num_devices); - h.num_devices = NULL; - } - if (h.drivers != NULL) { - free(h.drivers); - h.drivers = NULL; - } - h.num_drivers = 0; - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -int oneapi_get_device_count(oneapi_handle_t h, int driver) { - if (h.handle == NULL || h.num_devices == NULL) { - return 0; - } - if (driver > h.num_drivers) { - return 0; - } - return (int)h.num_devices[driver]; -} - -#endif // __APPLE__ diff --git a/discover/gpu_info_oneapi.h b/discover/gpu_info_oneapi.h deleted file mode 100644 index 97fcecd9c..000000000 --- a/discover/gpu_info_oneapi.h +++ /dev/null @@ -1,203 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_ONEAPI_H__ -#define __GPU_INFO_ONEAPI_H__ -#include "gpu_info.h" - -#define ZE_MAX_DEVICE_NAME 256 -#define ZE_MAX_DEVICE_UUID_SIZE 16 -#define ZES_STRING_PROPERTY_SIZE 64 -#define ZE_BIT(_i) (1 << _i) - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum ze_result_t { - ZE_RESULT_SUCCESS = 0, - // Other values omitted for now... -} ze_result_t; - -typedef uint8_t ze_bool_t; -typedef struct _zes_driver_handle_t *zes_driver_handle_t; -typedef struct _zes_device_handle_t *zes_device_handle_t; -typedef struct _zes_mem_handle_t *zes_mem_handle_t; - -typedef enum _ze_structure_type_t { - ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff -} ze_structure_type_t; - -typedef enum _zes_structure_type_t { - ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1, - ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb, - ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e, - ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES = 0x2d, - ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff -} zes_structure_type_t; - -typedef enum _zes_mem_type_t { - ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff -} zes_mem_type_t; - -typedef enum _zes_mem_loc_t { - ZES_MEM_LOC_SYSTEM = 0, - ZES_MEM_LOC_DEVICE = 1, - ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff -} zes_mem_loc_t; - -typedef enum _zes_mem_health_t { - ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff -} zes_mem_health_t; - -typedef struct _ze_device_uuid_t { - uint8_t id[ZE_MAX_DEVICE_UUID_SIZE]; -} ze_device_uuid_t; - -typedef struct _zes_uuid_t { - uint8_t id[ZE_MAX_DEVICE_UUID_SIZE]; -} zes_uuid_t; - -typedef enum _ze_device_type_t { - ZE_DEVICE_TYPE_GPU = 1, - ZE_DEVICE_TYPE_CPU = 2, - ZE_DEVICE_TYPE_FPGA = 3, - ZE_DEVICE_TYPE_MCA = 4, - ZE_DEVICE_TYPE_VPU = 5, - ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff -} ze_device_type_t; - -typedef enum _zes_device_type_t { - ZES_DEVICE_TYPE_GPU = 1, - ZES_DEVICE_TYPE_CPU = 2, - ZES_DEVICE_TYPE_FPGA = 3, - ZES_DEVICE_TYPE_MCA = 4, - ZES_DEVICE_TYPE_VPU = 5, - ZES_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff -} zes_device_type_t; - -typedef uint32_t ze_device_property_flags_t; -typedef enum _ze_device_property_flag_t { - ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0), - ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1), - ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2), - ZE_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3), - ZE_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff -} ze_device_property_flag_t; - -typedef uint32_t zes_device_property_flags_t; -typedef enum _zes_device_property_flag_t { - ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0), - ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1), - ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2), - ZES_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3), - ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff -} zes_device_property_flag_t; - -typedef struct _ze_device_properties_t { - ze_structure_type_t stype; - void *pNext; - ze_device_type_t type; - uint32_t vendorId; - uint32_t deviceId; - ze_device_property_flags_t flags; - uint32_t subdeviceId; - uint32_t coreClockRate; - uint64_t maxMemAllocSize; - uint32_t maxHardwareContexts; - uint32_t maxCommandQueuePriority; - uint32_t numThreadsPerEU; - uint32_t physicalEUSimdWidth; - uint32_t numEUsPerSubslice; - uint32_t numSubslicesPerSlice; - uint32_t numSlices; - uint64_t timerResolution; - uint32_t timestampValidBits; - uint32_t kernelTimestampValidBits; - ze_device_uuid_t uuid; - char name[ZE_MAX_DEVICE_NAME]; -} ze_device_properties_t; - -typedef struct _zes_device_properties_t { - zes_structure_type_t stype; - void *pNext; - ze_device_properties_t core; - uint32_t numSubdevices; - char serialNumber[ZES_STRING_PROPERTY_SIZE]; - char boardNumber[ZES_STRING_PROPERTY_SIZE]; - char brandName[ZES_STRING_PROPERTY_SIZE]; - char modelName[ZES_STRING_PROPERTY_SIZE]; - char vendorName[ZES_STRING_PROPERTY_SIZE]; - char driverVersion[ZES_STRING_PROPERTY_SIZE]; -} zes_device_properties_t; - -typedef struct _zes_device_ext_properties_t { - zes_structure_type_t stype; - void *pNext; - zes_uuid_t uuid; - zes_device_type_t type; - zes_device_property_flags_t flags; -} zes_device_ext_properties_t; - -typedef struct _zes_mem_properties_t { - zes_structure_type_t stype; - void *pNext; - zes_mem_type_t type; - ze_bool_t onSubdevice; - uint32_t subdeviceId; - zes_mem_loc_t location; - uint64_t physicalSize; - int32_t busWidth; - int32_t numChannels; -} zes_mem_properties_t; - -typedef struct _zes_mem_state_t { - zes_structure_type_t stype; - const void *pNext; - zes_mem_health_t health; - uint64_t free; - uint64_t size; -} zes_mem_state_t; - -typedef struct oneapi_handle { - void *handle; - uint16_t verbose; - - uint32_t num_drivers; - zes_driver_handle_t *drivers; - uint32_t *num_devices; - zes_device_handle_t **devices; - - // TODO Driver major, minor information - // int driver_major; - // int driver_minor; - - ze_result_t (*zesInit)(int); - ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers); - ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount, - zes_device_handle_t *phDevices); - ze_result_t (*zesDeviceGetProperties)(zes_device_handle_t hDevice, - zes_device_properties_t *pProperties); - ze_result_t (*zesDeviceEnumMemoryModules)(zes_device_handle_t hDevice, - uint32_t *pCount, - zes_mem_handle_t *phMemory); - ze_result_t (*zesMemoryGetProperties)(zes_mem_handle_t hMemory, - zes_mem_properties_t *pProperties); - ze_result_t (*zesMemoryGetState)(zes_mem_handle_t hMemory, - zes_mem_state_t *pState); - -} oneapi_handle_t; - -typedef struct oneapi_init_resp { - char *err; // If err is non-null handle is invalid - oneapi_handle_t oh; -} oneapi_init_resp_t; - -typedef struct oneapi_version_resp { - ze_result_t status; - char *str; // Contains version or error string if status != 0 -} oneapi_version_resp_t; - -void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp); -void oneapi_check_vram(oneapi_handle_t h, int driver, int device, - mem_info_t *resp); -void oneapi_release(oneapi_handle_t h); -int oneapi_get_device_count(oneapi_handle_t h, int driver); - -#endif // __GPU_INFO_INTEL_H__ -#endif // __APPLE__ diff --git a/discover/gpu_oneapi.go b/discover/gpu_oneapi.go deleted file mode 100644 index 77941f5b3..000000000 --- a/discover/gpu_oneapi.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build linux || windows - -package discover - -import ( - "log/slog" - "strings" -) - -func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { - ids := []string{} - for _, info := range gpuInfo { - if info.Library != "oneapi" { - // TODO shouldn't happen if things are wired correctly... - slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library) - continue - } - ids = append(ids, info.ID) - } - return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",") -} diff --git a/discover/gpu_test.go b/discover/gpu_test.go deleted file mode 100644 index 0c6ef7bad..000000000 --- a/discover/gpu_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package discover - -import ( - "runtime" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestBasicGetGPUInfo(t *testing.T) { - info := GetGPUInfo() - assert.NotEmpty(t, len(info)) - assert.Contains(t, "cuda rocm cpu metal", info[0].Library) - if info[0].Library != "cpu" { - assert.Greater(t, info[0].TotalMemory, uint64(0)) - assert.Greater(t, info[0].FreeMemory, uint64(0)) - } -} - -func TestCPUMemInfo(t *testing.T) { - info, err := GetCPUMem() - require.NoError(t, err) - switch runtime.GOOS { - case "darwin": - t.Skip("CPU memory not populated on darwin") - case "linux", "windows": - assert.Greater(t, info.TotalMemory, uint64(0)) - assert.Greater(t, info.FreeMemory, uint64(0)) - default: - return - } -} - -func TestByLibrary(t *testing.T) { - type testCase struct { - input []GpuInfo - expect int - } - - testCases := map[string]*testCase{ - "empty": {input: []GpuInfo{}, expect: 0}, - "cpu": {input: []GpuInfo{{Library: "cpu"}}, expect: 1}, - "cpu + GPU": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}}, expect: 2}, - "cpu + 2 GPU no variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}, {Library: "cuda"}}, expect: 2}, - "cpu + 2 GPU same variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v11"}}, expect: 2}, - "cpu + 2 GPU diff variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v12"}}, expect: 3}, - } - - for k, v := range testCases { - t.Run(k, func(t *testing.T) { - resp := (GpuInfoList)(v.input).ByLibrary() - if len(resp) != v.expect { - t.Fatalf("expected length %d, got %d => %+v", v.expect, len(resp), resp) - } - }) - } -} - -// TODO - add some logic to figure out card type through other means and actually verify we got back what we expected diff --git a/discover/runner.go b/discover/runner.go new file mode 100644 index 000000000..8071111fa --- /dev/null +++ b/discover/runner.go @@ -0,0 +1,542 @@ +package discover + +// Runner based GPU discovery + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "math/rand" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/ml" +) + +var ( + deviceMu sync.Mutex + devices []ml.DeviceInfo + libDirs map[string]struct{} + rocmDir string + exe string + bootstrapped bool +) + +func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.DeviceInfo { + deviceMu.Lock() + defer deviceMu.Unlock() + startDiscovery := time.Now() + msg := "overall device VRAM discovery took" + defer func() { + slog.Debug(msg, "duration", time.Since(startDiscovery)) + }() + + if !bootstrapped { + msg = "GPU bootstrap discovery took" + libDirs = make(map[string]struct{}) + var err error + exe, err = os.Executable() + if err != nil { + slog.Error("unable to lookup executable path", "error", err) + return nil + } + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval + } + files, err := filepath.Glob(filepath.Join(LibOllamaPath, "*", "*ggml-*")) + if err != nil { + slog.Debug("unable to lookup runner library directories", "error", err) + } + for _, file := range files { + libDirs[filepath.Dir(file)] = struct{}{} + } + + // Our current packaging model places ggml-hip in the main directory + // but keeps rocm in an isolated directory. We have to add it to + // the [LD_LIBRARY_]PATH so ggml-hip will load properly + rocmDir = filepath.Join(LibOllamaPath, "rocm") + if _, err := os.Stat(rocmDir); err != nil { + rocmDir = "" + } + + if len(libDirs) == 0 { + libDirs[""] = struct{}{} + } + + slog.Info("discovering available GPUs...") + + // For our initial discovery pass, we gather all the known GPUs through + // all the libraries that were detected. This pass may include GPUs that + // are enumerated, but not actually supported. + // We run this in serial to avoid potentially initializing a GPU multiple + // times concurrently leading to memory contention + for dir := range libDirs { + var dirs []string + if dir == "" { + dirs = []string{LibOllamaPath} + } else { + dirs = []string{LibOllamaPath, dir} + } + // Typically bootstrapping takes < 1s, but on some systems, with devices + // in low power/idle mode, initialization can take multiple seconds. We + // set a long timeout just for bootstrap discovery to reduce the chance + // of giving up too quickly + ctx1stPass, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // For this pass, we retain duplicates in case any are incompatible with some libraries + devices = append(devices, bootstrapDevices(ctx1stPass, dirs, nil)...) + } + + // In the second pass, we more deeply initialize the GPUs to weed out devices that + // aren't supported by a given library. We run this phase in parallel to speed up discovery. + slog.Debug("filtering out unsupported or overlapping GPU library combinations", "count", len(devices)) + ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + var wg sync.WaitGroup + needsDelete := make([]bool, len(devices)) + supportedMu := sync.Mutex{} + supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index + for i := range devices { + libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1] + if devices[i].Library == "Metal" { + continue + } + slog.Debug("verifying GPU is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "pci_id", devices[i].PCIID) + wg.Add(1) + go func(i int) { + defer wg.Done() + var envVar string + if devices[i].Library == "ROCm" { + if runtime.GOOS != "linux" { + envVar = "HIP_VISIBLE_DEVICES" + } else { + envVar = "ROCR_VISIBLE_DEVICES" + } + } else { + envVar = "CUDA_VISIBLE_DEVICES" + } + + extraEnvs := []string{ + "GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs + envVar + "=" + devices[i].ID, // Filter to just this one GPU + } + if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 { + needsDelete[i] = true + } else { + supportedMu.Lock() + if _, ok := supported[devices[i].Library]; !ok { + supported[devices[i].Library] = make(map[string]map[string]int) + } + if _, ok := supported[devices[i].Library][libDir]; !ok { + supported[devices[i].Library][libDir] = make(map[string]int) + } + supported[devices[i].Library][libDir][devices[i].ID] = i + supportedMu.Unlock() + } + }(i) + } + wg.Wait() + logutil.Trace("supported GPU library combinations", "supported", supported) + + // Mark for deletion any overlaps - favoring the library version that can cover all GPUs if possible + filterOverlapByLibrary(supported, needsDelete) + + // TODO if we ever support multiple ROCm library versions this algorithm will need to be adjusted to keep the rocmID numeric value correct + rocmID := 0 + for i := 0; i < len(needsDelete); i++ { + if needsDelete[i] { + logutil.Trace("removing unsupported or overlapping GPU combination", "libDir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1], "description", devices[i].Description, "compute", devices[i].Compute(), "pci_id", devices[i].PCIID) + devices = append(devices[:i], devices[i+1:]...) + needsDelete = append(needsDelete[:i], needsDelete[i+1:]...) + i-- + } else if devices[i].Library == "ROCm" { + if _, err := strconv.Atoi(devices[i].ID); err == nil { + // Replace the numeric ID with the post-filtered IDs + devices[i].FilteredID = devices[i].ID + devices[i].ID = strconv.Itoa(rocmID) + } + rocmID++ + } + } + + // Now filter out any overlap with different libraries (favor CUDA/ROCm over others) + for i := 0; i < len(devices); i++ { + for j := i + 1; j < len(devices); j++ { + // For this pass, we only drop exact duplicates + switch devices[i].Compare(devices[j]) { + case ml.SameBackendDevice: + // Same library and device, skip it + devices = append(devices[:j], devices[j+1:]...) + j-- + continue + case ml.DuplicateDevice: + // Different library, choose based on priority + var droppedDevice ml.DeviceInfo + if devices[i].Library == "CUDA" || devices[i].Library == "ROCm" { + droppedDevice = devices[j] + } else { + droppedDevice = devices[i] + devices[i] = devices[j] + } + devices = append(devices[:j], devices[j+1:]...) + j-- + + typeStr := "discrete" + if droppedDevice.Integrated { + typeStr = "iGPU" + } + slog.Debug("dropping duplicate device", + "id", droppedDevice.ID, + "library", droppedDevice.Library, + "compute", droppedDevice.Compute(), + "name", droppedDevice.Name, + "description", droppedDevice.Description, + "libdirs", strings.Join(droppedDevice.LibraryPath, ","), + "driver", droppedDevice.Driver(), + "pci_id", droppedDevice.PCIID, + "type", typeStr, + "total", format.HumanBytes2(droppedDevice.TotalMemory), + "available", format.HumanBytes2(droppedDevice.FreeMemory), + ) + continue + } + } + } + + // Reset the libDirs to what we actually wind up using for future refreshes + libDirs = make(map[string]struct{}) + for _, dev := range devices { + dir := dev.LibraryPath[len(dev.LibraryPath)-1] + if dir != LibOllamaPath { + libDirs[dir] = struct{}{} + } + } + if len(libDirs) == 0 { + libDirs[""] = struct{}{} + } + + bootstrapped = true + } else { + if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { + // metal never updates free VRAM + return devices + } + + slog.Debug("refreshing free memory") + updated := make([]bool, len(devices)) + allDone := func() bool { + allDone := true + for _, done := range updated { + if !done { + allDone = false + break + } + } + return allDone + } + + // First try to use existing runners to refresh VRAM since they're already + // active on GPU(s) + for _, runner := range runners { + if runner == nil { + continue + } + deviceIDs := runner.GetActiveDeviceIDs() + if len(deviceIDs) == 0 { + // Skip this runner since it doesn't have active GPU devices + continue + } + + // Check to see if this runner is active on any devices that need a refresh + skip := true + devCheck: + for _, dev := range deviceIDs { + for i := range devices { + if dev == devices[i].DeviceID { + if !updated[i] { + skip = false + break devCheck + } + } + } + } + if skip { + continue + } + + // Typical refresh on existing runner is ~500ms but allow longer if the system + // is under stress before giving up and using stale data. + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + start := time.Now() + updatedDevices := runner.GetDeviceInfos(ctx) + slog.Debug("existing runner discovery took", "duration", time.Since(start)) + for _, u := range updatedDevices { + for i := range devices { + if u.DeviceID == devices[i].DeviceID { + updated[i] = true + devices[i].FreeMemory = u.FreeMemory + break + } + } + } + // Short circuit if we've updated all the devices + if allDone() { + break + } + } + if !allDone() { + slog.Debug("unable to refresh all GPUs with existing runners, performing bootstrap discovery") + + // Bootstrapping may take longer in some cases (AMD windows), but we + // would rather use stale free data to get the model running sooner + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + for dir := range libDirs { + updatedDevices := bootstrapDevices(ctx, []string{LibOllamaPath, dir}, nil) + for _, u := range updatedDevices { + for i := range devices { + if u.DeviceID == devices[i].DeviceID { + updated[i] = true + devices[i].FreeMemory = u.FreeMemory + break + } + } + // TODO - consider evaluating if new devices have appeared (e.g. hotplug) + } + if allDone() { + break + } + } + if !allDone() { + slog.Warn("unable to refresh free memory, using old values") + } + } + } + + return devices +} + +func filterOverlapByLibrary(supported map[string]map[string]map[string]int, needsDelete []bool) { + // For multi-GPU systems, use the newest version that supports all the GPUs + for _, byLibDirs := range supported { + libDirs := make([]string, 0, len(byLibDirs)) + for libDir := range byLibDirs { + libDirs = append(libDirs, libDir) + } + sort.Sort(sort.Reverse(sort.StringSlice(libDirs))) + anyMissing := false + var newest string + for _, newest = range libDirs { + for _, libDir := range libDirs { + if libDir == newest { + continue + } + if len(byLibDirs[newest]) != len(byLibDirs[libDir]) { + anyMissing = true + break + } + for dev := range byLibDirs[newest] { + if _, found := byLibDirs[libDir][dev]; !found { + anyMissing = true + break + } + } + } + if !anyMissing { + break + } + } + // Now we can mark overlaps for deletion + for _, libDir := range libDirs { + if libDir == newest { + continue + } + for dev, i := range byLibDirs[libDir] { + if _, found := byLibDirs[newest][dev]; found { + needsDelete[i] = true + } + } + } + } +} + +type bootstrapRunner struct { + port int + cmd *exec.Cmd +} + +func (r *bootstrapRunner) GetPort() int { + return r.port +} + +func (r *bootstrapRunner) HasExited() bool { + if r.cmd != nil && r.cmd.ProcessState != nil { + return true + } + return false +} + +func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []string) []ml.DeviceInfo { + // TODO DRY out with llm/server.go + slog.Debug("spawing runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs) + start := time.Now() + defer func() { + slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs) + }() + port := 0 + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + port = l.Addr().(*net.TCPAddr).Port + l.Close() + } + } + if port == 0 { + slog.Debug("ResolveTCPAddr failed, using random port") + port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range + } + params := []string{"runner", "--ollama-engine", "--port", strconv.Itoa(port)} + var pathEnv string + switch runtime.GOOS { + case "windows": + pathEnv = "PATH" + case "darwin": + pathEnv = "DYLD_LIBRARY_PATH" + default: + pathEnv = "LD_LIBRARY_PATH" + } + libraryPaths := append([]string{LibOllamaPath}, ollamaLibDirs...) + if rocmDir != "" { + libraryPaths = append(libraryPaths, rocmDir) + } + // Note: we always put our dependency paths first + // since these are the exact version we compiled/linked against + if libraryPath, ok := os.LookupEnv(pathEnv); ok { + libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...) + } + + cmd := exec.Command(exe, params...) + cmd.Env = os.Environ() + if envconfig.LogLevel() == logutil.LevelTrace { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + } + // cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored + cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator))) + pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) + pathNeeded := true + extraDone := make([]bool, len(extraEnvs)) + for i := range cmd.Env { + cmp := strings.SplitN(cmd.Env[i], "=", 2) + if strings.EqualFold(cmp[0], pathEnv) { + cmd.Env[i] = pathEnv + "=" + pathEnvVal + pathNeeded = false + } else { + for j := range extraEnvs { + if extraDone[j] { + continue + } + extra := strings.SplitN(extraEnvs[j], "=", 2) + if cmp[0] == extra[0] { + cmd.Env[i] = extraEnvs[j] + extraDone[j] = true + } + } + } + } + if pathNeeded { + cmd.Env = append(cmd.Env, pathEnv+"="+pathEnvVal) + } + for i := range extraDone { + if !extraDone[i] { + cmd.Env = append(cmd.Env, extraEnvs[i]) + } + } + logutil.Trace("starting runner for device discovery", "env", cmd.Env, "cmd", cmd) + if err := cmd.Start(); err != nil { + slog.Warn("unable to start discovery subprocess", "cmd", cmd, "error", err) + return nil + } + go func() { + cmd.Wait() // exit status ignored + }() + + defer cmd.Process.Kill() + devices, err := GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd}) + if err != nil { + if cmd.ProcessState != nil && cmd.ProcessState.ExitCode() >= 0 { + // Expected during bootstrapping while we filter out unsupported AMD GPUs + logutil.Trace("runner exited", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "code", cmd.ProcessState.ExitCode()) + } else { + slog.Info("failure during GPU discovery", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "error", err) + } + } + logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices) + return devices +} + +func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]ml.DeviceInfo, error) { + var moreDevices []ml.DeviceInfo + port := runner.GetPort() + tick := time.Tick(10 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("failed to finish discovery before timeout") + case <-tick: + r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + r.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(r) + if err != nil { + // slog.Warn("failed to send request", "error", err) + if runner.HasExited() { + return nil, fmt.Errorf("runner crashed") + } + continue + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + // old runner, fall back to bootstrapping model + return nil, fmt.Errorf("llamarunner free vram reporting not supported") + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + slog.Warn("failed to read response", "error", err) + continue + } + if resp.StatusCode != 200 { + logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body) + return nil, fmt.Errorf("runner error: %s", string(body)) + } + + if err := json.Unmarshal(body, &moreDevices); err != nil { + slog.Warn("unmarshal encode response", "error", err) + continue + } + return moreDevices, nil + } + } +} diff --git a/discover/runner_test.go b/discover/runner_test.go new file mode 100644 index 000000000..9ea190461 --- /dev/null +++ b/discover/runner_test.go @@ -0,0 +1,108 @@ +package discover + +import ( + "testing" + + "github.com/ollama/ollama/app/lifecycle" +) + +func init() { + lifecycle.InitLogging() +} + +func TestFilterOverlapByLibrary(t *testing.T) { + type testcase struct { + name string + inp map[string]map[string]map[string]int + exp []bool + } + for _, tc := range []testcase{ + { + name: "empty", + inp: map[string]map[string]map[string]int{}, + exp: []bool{}, // needs deletion + }, + { + name: "single no overlap", + inp: map[string]map[string]map[string]int{ + "CUDA": { + "cuda_v12": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0, + }, + }, + }, + exp: []bool{false}, + }, + { + name: "100% overlap pick 2nd", + inp: map[string]map[string]map[string]int{ + "CUDA": { + "cuda_v12": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0, + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 1, + }, + "cuda_v13": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 2, + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 3, + }, + }, + }, + exp: []bool{true, true, false, false}, + }, + { + name: "100% overlap pick 1st", + inp: map[string]map[string]map[string]int{ + "CUDA": { + "cuda_v13": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0, + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 1, + }, + "cuda_v12": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 2, + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 3, + }, + }, + }, + exp: []bool{false, false, true, true}, + }, + { + name: "partial overlap pick older", + inp: map[string]map[string]map[string]int{ + "CUDA": { + "cuda_v13": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0, + }, + "cuda_v12": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 1, + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 2, + }, + }, + }, + exp: []bool{true, false, false}, + }, + { + name: "no overlap", + inp: map[string]map[string]map[string]int{ + "CUDA": { + "cuda_v13": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0, + }, + "cuda_v12": { + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 1, + }, + }, + }, + exp: []bool{false, false}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + needsDelete := make([]bool, len(tc.exp)) + filterOverlapByLibrary(tc.inp, needsDelete) + for i, exp := range tc.exp { + if needsDelete[i] != exp { + t.Fatalf("expected: %v\ngot: %v", tc.exp, needsDelete) + } + } + }) + } +} diff --git a/discover/types.go b/discover/types.go index 13a030fd5..718809f44 100644 --- a/discover/types.go +++ b/discover/types.go @@ -1,10 +1,14 @@ package discover import ( - "fmt" + "context" "log/slog" + "path/filepath" + "runtime" + "strings" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/ml" ) type memInfo struct { @@ -15,8 +19,8 @@ type memInfo struct { // Beginning of an `ollama info` command type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? + ml.DeviceID memInfo - Library string `json:"library,omitempty"` // Optional variant to select (e.g. versions, cpu feature flags) Variant string `json:"variant"` @@ -27,18 +31,15 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? // Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly DependencyPath []string `json:"lib_path,omitempty"` - // Extra environment variables specific to the GPU as list of [key,value] - EnvWorkarounds [][2]string `json:"envs,omitempty"` - // Set to true if we can NOT reliably discover FreeMemory. A value of true indicates // the FreeMemory is best effort, and may over or under report actual memory usage // False indicates FreeMemory can generally be trusted on this GPU UnreliableFreeMemory bool // GPU information - ID string `json:"gpu_id"` // string to use for selection of this specific GPU - Name string `json:"name"` // user friendly name if available - Compute string `json:"compute"` // Compute Capability or gfx + filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices + Name string `json:"name"` // user friendly name if available + Compute string `json:"compute"` // Compute Capability or gfx // Driver Information - TODO no need to put this on each GPU DriverMajor int `json:"driver_major,omitempty"` @@ -69,37 +70,8 @@ type CPU struct { ThreadCount int } -type CudaGPUInfo struct { - GpuInfo - OSOverhead uint64 // Memory overhead between the driver library and management library - index int //nolint:unused,nolintlint - computeMajor int //nolint:unused,nolintlint - computeMinor int //nolint:unused,nolintlint -} -type CudaGPUInfoList []CudaGPUInfo - -type RocmGPUInfo struct { - GpuInfo - usedFilepath string //nolint:unused,nolintlint - index int //nolint:unused,nolintlint -} -type RocmGPUInfoList []RocmGPUInfo - -type OneapiGPUInfo struct { - GpuInfo - driverIndex int //nolint:unused,nolintlint - gpuIndex int //nolint:unused,nolintlint -} -type OneapiGPUInfoList []OneapiGPUInfo - type GpuInfoList []GpuInfo -type UnsupportedGPUInfo struct { - GpuInfo - Reason string `json:"reason"` -} - -// Split up the set of gpu info's by Library and variant func (l GpuInfoList) ByLibrary() []GpuInfoList { resp := []GpuInfoList{} libs := []string{} @@ -124,18 +96,47 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList { return resp } -// Report the GPU information into the log an Info level -func (l GpuInfoList) LogDetails() { - for _, g := range l { +func LogDetails(devices []ml.DeviceInfo) { + for _, dev := range devices { + var libs []string + for _, dir := range dev.LibraryPath { + if strings.Contains(dir, filepath.Join("lib", "ollama")) { + libs = append(libs, filepath.Base(dir)) + } + } + typeStr := "discrete" + if dev.Integrated { + typeStr = "iGPU" + } slog.Info("inference compute", - "id", g.ID, - "library", g.Library, - "variant", g.Variant, - "compute", g.Compute, - "driver", fmt.Sprintf("%d.%d", g.DriverMajor, g.DriverMinor), - "name", g.Name, - "total", format.HumanBytes2(g.TotalMemory), - "available", format.HumanBytes2(g.FreeMemory), + "id", dev.ID, + "library", dev.Library, + "compute", dev.Compute(), + "name", dev.Name, + "description", dev.Description, + "libdirs", strings.Join(libs, ","), + "driver", dev.Driver(), + "pci_id", dev.PCIID, + "type", typeStr, + "total", format.HumanBytes2(dev.TotalMemory), + "available", format.HumanBytes2(dev.FreeMemory), + ) + } + // CPU inference + if len(devices) == 0 { + dev, _ := GetCPUMem() + slog.Info("inference compute", + "id", "cpu", + "library", "cpu", + "compute", "", + "name", "cpu", + "description", "cpu", + "libdirs", "ollama", + "driver", "", + "pci_id", "", + "type", "", + "total", format.HumanBytes2(dev.TotalMemory), + "available", format.HumanBytes2(dev.FreeMemory), ) } } @@ -148,16 +149,15 @@ func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory } type SystemInfo struct { - System CPUInfo `json:"system"` - GPUs []GpuInfo `json:"gpus"` - UnsupportedGPUs []UnsupportedGPUInfo `json:"unsupported_gpus"` - DiscoveryErrors []string `json:"discovery_errors"` + System CPUInfo `json:"system"` + GPUs []GpuInfo `json:"gpus"` } // Return the optimal number of threads to use for inference func (si SystemInfo) GetOptimalThreadCount() int { if len(si.System.CPUs) == 0 { - return 0 + // Fall back to Go's num CPU + return runtime.NumCPU() } coreCount := 0 @@ -172,9 +172,9 @@ func (si SystemInfo) GetOptimalThreadCount() int { func (l GpuInfoList) FlashAttentionSupported() bool { for _, gpu := range l { supportsFA := gpu.Library == "cpu" || - gpu.Library == "metal" || - (gpu.Library == "cuda" && gpu.DriverMajor >= 7) || - gpu.Library == "rocm" + gpu.Name == "Metal" || gpu.Library == "Metal" || + (gpu.Library == "CUDA" && gpu.DriverMajor >= 7) || + gpu.Library == "ROCm" if !supportsFA { return false @@ -182,3 +182,31 @@ func (l GpuInfoList) FlashAttentionSupported() bool { } return true } + +type BaseRunner interface { + // GetPort returns the localhost port number the runner is running on + GetPort() int + + // HasExited indicates if the runner is no longer running. This can be used during + // bootstrap to detect if a given filtered device is incompatible and triggered an assert + HasExited() bool +} + +type RunnerDiscovery interface { + BaseRunner + + // GetDeviceInfos will perform a query of the underlying device libraries + // for device identification and free VRAM information + // During bootstrap scenarios, this routine may take seconds to complete + GetDeviceInfos(ctx context.Context) []ml.DeviceInfo +} + +type FilteredRunnerDiscovery interface { + RunnerDiscovery + + // GetActiveDeviceIDs returns the filtered set of devices actively in + // use by this runner for running models. If the runner is a bootstrap runner, no devices + // will be active yet so no device IDs are returned. + // This routine will not query the underlying device and will return immediately + GetActiveDeviceIDs() []ml.DeviceID +} diff --git a/docs/api.md b/docs/api.md index f11d59ed1..f47af63c6 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1708,6 +1708,7 @@ Advanced parameters: - `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) +- `dimensions`: number of dimensions for the embedding ### Examples diff --git a/docs/cloud.md b/docs/cloud.md new file mode 100644 index 000000000..300e6f5e0 --- /dev/null +++ b/docs/cloud.md @@ -0,0 +1,40 @@ +# Cloud + +| Ollama's cloud is currently in preview. For full documentation, see [Ollama's documentation](https://docs.ollama.com/cloud). + +## Cloud Models + +[Cloud models](https://ollama.com/cloud) are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn’t fit on a personal computer. + +Ollama currently supports the following cloud models, with more coming soon: + +- `gpt-oss:20b-cloud` +- `gpt-oss:120b-cloud` +- `deepseek-v3.1:671b-cloud` +- `qwen3-coder:480b-cloud` + +### Get started + +To run a cloud model, open the terminal and run: + +``` +ollama run gpt-oss:120b-cloud +``` + +To run cloud models with integrations that work with Ollama, first download the cloud model: + +``` +ollama pull qwen3-coder:480b-cloud +``` + +Then sign in to Ollama: + +``` +ollama signin +``` + +Finally, access the model using the model name `qwen3-coder:480b-cloud` via Ollama's local API or tooling. + +## Cloud API access + +Cloud models can also be accessed directly on ollama.com's API. For more information, see the [docs](https://docs.ollama.com/cloud). diff --git a/docs/development.md b/docs/development.md index 9726b5d91..ff07b5fb6 100644 --- a/docs/development.md +++ b/docs/development.md @@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository: go run . serve ``` +> [!NOTE] +> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first. + + ## macOS (Apple Silicon) macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required. diff --git a/docs/gpu.md b/docs/gpu.md index 464788ccb..ec5c9ccc3 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -65,6 +65,9 @@ With ROCm v6.1, the following GPUs are supported on Windows. | AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` | | AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` | +### Known Workarounds + +- The RX Vega 56 requires `HSA_ENABLE_SDMA=0` to disable SDMA ### Overrides on Linux Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In diff --git a/docs/linux.md b/docs/linux.md index 9a156d1dc..ce5ed860b 100644 --- a/docs/linux.md +++ b/docs/linux.md @@ -11,12 +11,13 @@ curl -fsSL https://ollama.com/install.sh | sh ## Manual install > [!NOTE] -> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first. +> If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first. Download and extract the package: ```shell curl -LO https://ollama.com/download/ollama-linux-amd64.tgz +sudo rm -rf /usr/lib/ollama sudo tar -C /usr -xzf ollama-linux-amd64.tgz ``` diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 6fdd3e85b..7647b12f9 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -92,6 +92,9 @@ If none of those resolve the problem, gather additional information and file an - Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs - Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia` +You may get more details for initialization failures by enabling debug prints in the uvm driver. You should only use this temporarily while troubleshooting +- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1` + ## AMD GPU Discovery diff --git a/docs/turbo.md b/docs/turbo.md deleted file mode 100644 index d75d95570..000000000 --- a/docs/turbo.md +++ /dev/null @@ -1,107 +0,0 @@ -# Turbo - -> ⚠️ Turbo is preview - -Ollama’s [Turbo](https://ollama.com/turbo) is a new way to run open-source models with acceleration from datacenter-grade hardware. - -Currently, the following models are available in Turbo: - -- `gpt-oss:20b` -- `gpt-oss:120b` - -## Get started - -### Ollama for macOS & Windows - -Download Ollama - -- Select a model such as `gpt-oss:20b` or `gpt-oss:120b` -- Click on **Turbo**. You’ll be prompted to create an account or sign in - -### Ollama’s CLI - -- [Sign up](https://ollama.com/signup) for an Ollama account -- Add your Ollama key [to ollama.com](https://ollama.com/settings/keys). - - On macOS and Linux: - - ```shell - cat ~/.ollama/id_ed25519.pub - ``` - - On Windows: - - ``` - type "%USERPROFILE%\.ollama\id_ed25519.pub" - ``` - -- Then run a model setting `OLLAMA_HOST` to `ollama.com`: - ```shell - OLLAMA_HOST=ollama.com ollama run gpt-oss:120b - ``` - -### Ollama’s Python library - -- Download Ollama's [Python library](https://github.com/ollama/ollama-python) -- [Sign up](https://ollama.com/signup) for an Ollama account -- Create an API key by visiting https://ollama.com/settings/keys - -```python -from ollama import Client - -client = Client( - host="https://ollama.com", - headers={'Authorization': ''} -) - -messages = [ - { - 'role': 'user', - 'content': 'Why is the sky blue?', - }, -] - -for part in client.chat('gpt-oss:120b', messages=messages, stream=True): - print(part['message']['content'], end='', flush=True) -``` - -### Ollama’s JavaScript library - -- Download Ollama's [JavaScript library](https://github.com/ollama/ollama-js) -- [Sign up](https://ollama.com/signup) for an Ollama account -- Create an API key by visiting https://ollama.com/settings/keys - -```typescript -import { Ollama } from 'ollama'; - -const ollama = new Ollama({ - host: 'https://ollama.com', - headers: { - Authorization: "Bearer " - } -}); - -const response = await ollama.chat({ - model: 'gpt-oss:120b', - messages: [{ role: 'user', content: 'Explain quantum computing' }], - stream: true -}); - -for await (const part of response) { - process.stdout.write(part.message.content) -} -``` - -### Community integrations - -Turbo mode is also compatible with several community integrations. - -#### Open WebUI - -- Go to **settings** → **Admin settings** → **Connections** -- Under **Ollama API,** click **+** -- For the **URL** put `https://ollama.com` -- For the **API key,** create an API key on https://ollama.com/settings/keys and add it. -- Click **Save** - -Now, if you navigate to the model selector, Turbo models should be available under **External**. diff --git a/envconfig/config.go b/envconfig/config.go index 868813ae8..09243ab95 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -134,6 +134,17 @@ func LoadTimeout() (loadTimeout time.Duration) { return loadTimeout } +func Remotes() []string { + var r []string + raw := strings.TrimSpace(Var("OLLAMA_REMOTES")) + if raw == "" { + r = []string{"ollama.com"} + } else { + r = strings.Split(raw, ",") + } + return r +} + func Bool(k string) func() bool { return func() bool { if s := Var(k); s != "" { @@ -185,8 +196,6 @@ var ( ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096) // Auth enables authentication between the Ollama client and server UseAuth = Bool("OLLAMA_AUTH") - // Enable the new memory estimation logic - NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES") ) func String(s string) func() string { @@ -272,7 +281,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, - "OLLAMA_NEW_ESTIMATES": {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"}, + "OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"}, // Informational "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index a739e99ba..58803f58f 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "log/slog" + "math" "slices" "strings" + "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/util/bufioutil" ) @@ -55,10 +57,28 @@ func (kv KV) EmbeddingLength() uint64 { return uint64(kv.Uint("embedding_length")) } +func (kv KV) HeadCount() []uint64 { + headCountDefault := uint32(1) + headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault) + if len(headCount) == 1 { + headCountDefault = headCount[0] + } + nLayers := int(kv.BlockCount()) + if len(headCount) > nLayers { + slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers) + } + out := make([]uint64, nLayers) + for i := range nLayers { + if i >= len(headCount) { + out[i] = uint64(headCountDefault) + } else { + out[i] = uint64(headCount[i]) + } + } + return out +} + func (kv KV) HeadCountMax() uint64 { - // TODO(drifkin): using the max value can cause an overestimation. In the - // future if array values become more popular, we can adapt the more invasive - // return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1)) } @@ -66,6 +86,27 @@ func (kv KV) HeadCountMin() uint64 { return uint64(kv.UintOrMinArrayValue("attention.head_count", 1)) } +func (kv KV) HeadCountKV() []uint64 { + headCountKVDefault := uint32(1) + headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault) + if len(headCountKV) == 1 { + headCountKVDefault = headCountKV[0] + } + nLayers := int(kv.BlockCount()) + if len(headCountKV) > nLayers { + slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers) + } + out := make([]uint64, nLayers) + for i := range nLayers { + if i >= len(headCountKV) { + out[i] = uint64(headCountKVDefault) + } else { + out[i] = uint64(headCountKV[i]) + } + } + return out +} + func (kv KV) HeadCountKVMax() uint64 { return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1)) } @@ -98,6 +139,26 @@ func (kv KV) ChatTemplate() string { return kv.String("tokenizer.chat_template") } +// ssm architecture parameters + +func (kv KV) SSMConvKernel() uint64 { + return uint64(kv.Uint("ssm.conv_kernel")) +} + +func (kv KV) SSMInnerSize() uint64 { + return uint64(kv.Uint("ssm.inner_size")) +} + +func (kv KV) SSMStateSize() uint64 { + return uint64(kv.Uint("ssm.state_size")) +} + +func (kv KV) SSMGroupCount() uint64 { + return uint64(kv.Uint("ssm.group_count")) +} + +// general types + func (kv KV) String(key string, defaultValue ...string) string { val, _ := keyValue(kv, key, append(defaultValue, "")...) return val @@ -129,22 +190,27 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 { } func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) { + arrVal := kv.UintOrArrayValueAsArray(key, defaultValue) + return slices.Min(arrVal), slices.Max(arrVal) +} + +func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 { if u32, ok := keyValue(kv, key, uint32(0)); ok { - return u32, u32 + return []uint32{u32} } else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok { - min := slices.Min(u32s.values) - max := slices.Max(u32s.values) - return min, max + return u32s.values } else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok { - min := slices.Min(i32s.values) - max := slices.Max(i32s.values) - if min < 0 || max < 0 { - slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max) + dst := make([]uint32, len(i32s.values)) + for i, v := range i32s.values { + if v < 0 { + slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v) + } + dst[i] = uint32(v) } - return uint32(min), uint32(max) + return dst } - return defaultValue, defaultValue + return []uint32{defaultValue} } func (kv KV) Strings(key string, defaultValue ...[]string) []string { @@ -177,6 +243,8 @@ func (kv KV) OllamaEngineRequired() bool { "gemma3", "gemma3n", "mistral3", + "qwen3", + "qwen3moe", "llama4", "mllama", "qwen25vl", @@ -275,7 +343,7 @@ type Tensor struct { func (t Tensor) block() (n int) { if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil { - return -1 + return math.MaxInt } return @@ -288,24 +356,24 @@ func (t Tensor) blockSize() uint64 { func (t TensorType) BlockSize() uint64 { switch t { case - 0, // F32 - 1, // F16 - 24, // I8 - 25, // I16 - 26, // I32 - 27, // I64 - 28, // F64 - 30: // BF16 + TensorTypeF32, + TensorTypeF16, + TensorTypeI8, + TensorTypeI16, + TensorTypeI32, + TensorTypeI64, + TensorTypeF64, + TensorTypeBF16: return 1 case - 2, // Q4_0 - 3, // Q4_1 - 4, // MXFP4 - 6, // Q5_0 - 7, // Q5_1 - 8, // Q8_0 - 9, // Q8_1 - 20: // IQ4_NL + TensorTypeQ4_0, + TensorTypeQ4_1, + TensorTypeQ5_0, + TensorTypeQ5_1, + TensorTypeQ8_0, + TensorTypeQ8_1, + tensorTypeIQ4_NL, + 4, TensorTypeMXFP4: return 32 default: return 256 @@ -328,8 +396,6 @@ func (t TensorType) TypeSize() uint64 { return 2 + blockSize/2 case TensorTypeQ4_1: return 2 + 2 + blockSize/2 - case TensorTypeMXFP4, 39: - return 1 + blockSize/2 case TensorTypeQ5_0: return 2 + 4 + blockSize/2 case TensorTypeQ5_1: @@ -380,6 +446,8 @@ func (t TensorType) TypeSize() uint64 { return blockSize/8 + blockSize/16 + blockSize/32 case TensorTypeBF16: return 2 + case 4, TensorTypeMXFP4: + return 1 + blockSize/2 default: return 0 } @@ -479,12 +547,14 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) { }, nil } -func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) { +func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) { context *= uint64(numParallel) embedding := f.KV().EmbeddingLength() heads := f.KV().HeadCountMax() + headsArr := f.KV().HeadCount() headsKV := f.KV().HeadCountKVMax() + headsKVArr := f.KV().HeadCountKV() vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size) embeddingHeads := f.KV().EmbeddingHeadCountMax() @@ -494,12 +564,51 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri layers := f.Tensors().GroupLayers() bytesPerElement := kvCacheBytesPerElement(kvCacheType) + + // Default for models unless special-cased below. These defaults mirror the + // cache usage in llama.cpp under the assumption that models without special + // cases below will use the llamarunner and caching will be handled by the + // llama.cpp layer. + // + // This also assumes that a layer without heads or headsKV set is recurrent + // which is usually the case. Some models (eg nemotronh) use "blocks" in + // place of layers where some are MLP blocks that don't have any cache. + // Models like this will need a special case below to be accurately + // estimated. var kvTotal uint64 kv = make([]uint64, f.KV().BlockCount()) + kvSizeAttn := uint64(0) + kvSizeRecurrent := uint64(0) for i := range kv { - kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement) + headsL := headsArr[i] + headsKVL := headsKVArr[i] + if headsL > 0 && headsKVL > 0 { + // full attention layer + // NOTE: Assumes uniform values for all attn layers + kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement) + kvSizeAttn += kv[i] + } else { + // recurrent layer + ssmDConv := f.KV().SSMConvKernel() + ssmDState := f.KV().SSMStateSize() + ssmDInner := f.KV().SSMInnerSize() + ssmNGroups := f.KV().SSMGroupCount() + nEmbdR := uint64(0) + if ssmDConv > 0 { + nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState) + } + nEmbdS := ssmDState * ssmDInner + + // recurrent always uses F32 in llama.cpp backend + // https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644 + bytesPerElementRecurrent := kvCacheBytesPerElement("f32") + + kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent) + kvSizeRecurrent += kv[i] + } kvTotal += kv[i] } + slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent) switch f.KV().Architecture() { case "llama", "llama4": @@ -677,7 +786,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri kv[i] *= context } } + partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6 + if useFlashAttention { + // rough estimate of graph size with flash attention on + partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte + } } return @@ -752,12 +866,16 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) { // SupportsKVCacheType checks if the requested cache type is supported func (f GGML) SupportsKVCacheType(cacheType string) bool { + if cacheType == "" || cacheType == "f16" { + return true + } + if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) { // gpt-oss uses attention with sinks which does not support quantized cache types - slog.Warn("model only supports non-quantized cache types ", "mode", arch) - return cacheType == "f16" + slog.Warn("model only supports non-quantized cache types", "model", arch) + return false } - return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType) + return slices.Contains([]string{"q8_0", "q4_0"}, cacheType) } // SupportsFlashAttention checks if the model supports flash attention @@ -767,12 +885,23 @@ func (f GGML) SupportsFlashAttention() bool { return false } + if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) { + return false + } + // Check head counts match and are non-zero headCountK := f.KV().EmbeddingHeadCountK() headCountV := f.KV().EmbeddingHeadCountV() return headCountK != 0 && headCountV != 0 && headCountK == headCountV } +// FlashAttention checks if the model should enable flash attention +func (f GGML) FlashAttention() bool { + return slices.Contains([]string{ + "gptoss", "gpt-oss", + }, f.KV().String("general.architecture")) +} + // kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type func kvCacheBytesPerElement(cacheType string) float64 { switch cacheType { @@ -780,6 +909,8 @@ func kvCacheBytesPerElement(cacheType string) float64 { return 1 // 1/2 of fp16 case "q4_0": return 0.5 // 1/4 of fp16 + case "f32": + return 4 // f32 (default for recurrent) default: return 2 // f16 (default) } diff --git a/fs/ggml/gguf.go b/fs/ggml/gguf.go index 413eab5ed..fa613ca4b 100644 --- a/fs/ggml/gguf.go +++ b/fs/ggml/gguf.go @@ -533,12 +533,15 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error { } } - slices.SortStableFunc(ts, func(a, b *Tensor) int { - if i, j := a.block(), b.block(); i > 0 && j > 0 { - return cmp.Compare(i, j) - } - return cmp.Compare(a.Name, b.Name) - }) + slices.SortStableFunc( + ts, + func(a, b *Tensor) int { + return cmp.Or( + cmp.Compare(a.block(), b.block()), + cmp.Compare(a.Name, b.Name), + ) + }, + ) var s uint64 for i := range ts { diff --git a/fs/ggml/gguf_test.go b/fs/ggml/gguf_test.go index bf7679182..e56bab8d2 100644 --- a/fs/ggml/gguf_test.go +++ b/fs/ggml/gguf_test.go @@ -11,24 +11,24 @@ import ( ) func TestWriteGGUF(t *testing.T) { - r := rand.New(rand.NewPCG(0, 0)) + b := bytes.NewBuffer(make([]byte, 2*3)) for range 8 { t.Run("shuffle", func(t *testing.T) { t.Parallel() ts := []*Tensor{ - {Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, - {Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, - {Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, - {Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, - {Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, - {Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, - {Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, - {Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))}, - {Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))}, + {Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b}, + {Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b}, + {Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b}, + {Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b}, + {Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b}, + {Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b}, } - r.Shuffle(len(ts), func(i, j int) { + rand.Shuffle(len(ts), func(i, j int) { ts[i], ts[j] = ts[j], ts[i] }) @@ -63,14 +63,14 @@ func TestWriteGGUF(t *testing.T) { } if diff := cmp.Diff(Tensors{ - Offset: 608, + Offset: 592, items: []*Tensor{ - {Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}}, - {Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}}, - {Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}}, - {Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}}, - {Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}}, - {Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}}, + {Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}}, + {Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}}, + {Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}}, + {Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}}, + {Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}}, + {Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}}, {Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}}, {Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}}, {Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}}, diff --git a/fs/ggml/type.go b/fs/ggml/type.go index 3e5deb87b..1a31a5fd8 100644 --- a/fs/ggml/type.go +++ b/fs/ggml/type.go @@ -146,8 +146,6 @@ func (ftype FileType) ToTensorType() TensorType { return TensorTypeQ4_0 case fileTypeQ4_1: return TensorTypeQ4_1 - case fileTypeMXFP4: - return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2 case FileTypeQ8_0: return TensorTypeQ8_0 case fileTypeQ5_0: @@ -176,6 +174,8 @@ func (ftype FileType) ToTensorType() TensorType { return TensorTypeQ2_K case FileTypeBF16: return TensorTypeBF16 + case fileTypeMXFP4: + return TensorTypeMXFP4 default: slog.Warn("unsupported file type", "type", ftype) return 0 // F32 @@ -191,8 +191,8 @@ const ( TensorTypeF16 TensorTypeQ4_0 TensorTypeQ4_1 - TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2 - tensorTypeQ4_3 // unused by GGML + tensorTypeQ4_2 + tensorTypeQ4_3 // unused by GGML TensorTypeQ5_0 TensorTypeQ5_1 TensorTypeQ8_0 @@ -226,6 +226,7 @@ const ( tensorTypeIQ4_NL_4_4 // unused by GGML tensorTypeIQ4_NL_4_8 // unused by GGML tensorTypeIQ4_NL_8_8 // unused by GGML + TensorTypeMXFP4 ) // ParseFileType parses the provided GGUF file type @@ -318,7 +319,7 @@ func (t TensorType) String() string { return "F64" case TensorTypeBF16: return "BF16" - case TensorTypeMXFP4: + case 4, TensorTypeMXFP4: return "MXFP4" default: return "unknown" diff --git a/server/harmonyparser.go b/harmony/harmonyparser.go similarity index 81% rename from server/harmonyparser.go rename to harmony/harmonyparser.go index 4405cea44..da9fe3e93 100644 --- a/server/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -1,10 +1,9 @@ -package server +package harmony import ( - "context" + "encoding/json" "fmt" "log/slog" - "slices" "strings" "unicode" @@ -20,18 +19,6 @@ const ( harmonyParserState_ParsingContent ) -func shouldUseHarmony(model Model) bool { - if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { - // heuristic to check whether the template expects to be parsed via harmony: - // search for harmony tags that are nearly always used - if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") { - return true - } - } - - return false -} - func (s harmonyParserState) String() string { switch s { // we're looking for the message start tag @@ -277,20 +264,23 @@ const ( // This is a higher level interface that maps harmony concepts into ollama concepts type HarmonyMessageHandler struct { state harmonyMessageState - harmonyParser *HarmonyParser - functionNameMap *FunctionNameMap + HarmonyParser *HarmonyParser + FunctionNameMap *FunctionNameMap + toolAccumulator *HarmonyToolCallAccumulator + convertedTools map[string]struct{} } // NewHarmonyMessageHandler creates a new message handler func NewHarmonyMessageHandler() *HarmonyMessageHandler { return &HarmonyMessageHandler{ state: harmonyMessageState_Normal, - harmonyParser: &HarmonyParser{ + HarmonyParser: &HarmonyParser{ MessageStartTag: "<|start|>", MessageEndTag: "<|end|>", HeaderEndTag: "<|message|>", }, - functionNameMap: NewFunctionNameMap(), + FunctionNameMap: NewFunctionNameMap(), + convertedTools: make(map[string]struct{}), } } @@ -301,11 +291,11 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo thinkingSb := strings.Builder{} toolContentSb := strings.Builder{} - events := h.harmonyParser.AddContent(content) + events := h.HarmonyParser.AddContent(content) for _, event := range events { switch event := event.(type) { case HarmonyEventHeaderComplete: - slog.Log(context.TODO(), logutil.LevelTrace, "harmony event header complete", "header", event.Header) + logutil.Trace("harmony event header complete", "header", event.Header) switch event.Header.Channel { case "analysis": if event.Header.Recipient != "" { @@ -328,7 +318,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo h.state = harmonyMessageState_Normal } case HarmonyEventContentEmitted: - slog.Log(context.TODO(), logutil.LevelTrace, "harmony event content", "content", event.Content, "state", h.state) + logutil.Trace("harmony event content", "content", event.Content, "state", h.state) if h.state == harmonyMessageState_Normal { contentSb.WriteString(event.Content) } else if h.state == harmonyMessageState_Thinking { @@ -398,8 +388,85 @@ func NewFunctionNameMap() *FunctionNameMap { } } +// Init initializes the handler with tools and optional last message +// Implements the Parser interface +func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + // Initialize the harmony parser + if h.HarmonyParser == nil { + h.HarmonyParser = &HarmonyParser{ + MessageStartTag: "<|start|>", + MessageEndTag: "<|end|>", + HeaderEndTag: "<|message|>", + } + } + + // Handle prefill for chat mode + if lastMessage != nil { + h.HarmonyParser.AddImplicitStartOrPrefill(lastMessage) + } else { + h.HarmonyParser.AddImplicitStart() + } + + // Initialize tool accumulator + h.toolAccumulator = h.CreateToolParser() + + // Process tools and return renamed versions + if len(tools) == 0 { + return tools + } + + processedTools := make([]api.Tool, len(tools)) + copy(processedTools, tools) + for i, tool := range processedTools { + if tool.Function.Name != "" { + processedTools[i].Function.Name = h.FunctionNameMap.ConvertAndAdd(tool.Function.Name) + h.convertedTools[tool.Function.Name] = struct{}{} + } + } + return processedTools +} + +// Add implements the Parser interface - processes streamed content and extracts content, thinking, and tool calls +func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + content, thinking, toolContent := h.AddContent(s, h.toolAccumulator) + if toolContent != "" { + h.toolAccumulator.Add(toolContent) + } + + // tool calls always happen one at a time, and always at the end of a message, + // so for simplicity we defer parsing them until we know we're done + if done { + toolName, raw := h.toolAccumulator.Drain() + if toolName != nil { + name := strings.TrimPrefix(*toolName, "functions.") + name = h.FunctionNameMap.OriginalFromConverted(name) + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(raw), &args); err != nil { + return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', err=%w", raw, err) + } + calls = append(calls, api.ToolCall{Function: api.ToolCallFunction{Name: name, Arguments: args}}) + } + } + + return content, thinking, calls, nil +} + +// HasToolSupport implements the Parser interface +func (h *HarmonyMessageHandler) HasToolSupport() bool { + return true +} + +// HasThinkingSupport implements the Parser interface +func (h *HarmonyMessageHandler) HasThinkingSupport() bool { + return true +} + func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string { harmonyFunctionName := m.deriveName(userFunctionName) + // built-in functions should not be renamed + if userFunctionName == "browser.open" || userFunctionName == "browser.search" || userFunctionName == "browser.find" || userFunctionName == "python" { + harmonyFunctionName = userFunctionName + } m.userToHarmony[userFunctionName] = harmonyFunctionName m.harmonyToUser[harmonyFunctionName] = userFunctionName return harmonyFunctionName diff --git a/server/harmonyparser_test.go b/harmony/harmonyparser_test.go similarity index 98% rename from server/harmonyparser_test.go rename to harmony/harmonyparser_test.go index 8a22f3404..e56178c61 100644 --- a/server/harmonyparser_test.go +++ b/harmony/harmonyparser_test.go @@ -1,4 +1,4 @@ -package server +package harmony import ( "fmt" @@ -513,6 +513,7 @@ func TestFunctionConvertAndAdd(t *testing.T) { {name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}}, {name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}}, {name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}}, + {name: "built-in functions should not be renamed", in: []string{"browser.open", "python", "not.a.built-in.function", "browser.not_a_real_built_in"}, want: []string{"browser.open", "python", "not_a_built_in_function", "browser_not_a_real_built_in"}}, } for i, tt := range tests { diff --git a/integration/README.md b/integration/README.md index e2bdd6b21..1dfd0e359 100644 --- a/integration/README.md +++ b/integration/README.md @@ -2,10 +2,16 @@ This directory contains integration tests to exercise Ollama end-to-end to verify behavior -By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` +By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"` The integration tests have 2 modes of operating. 1. By default, they will start the server on a random port, run the tests, and then shutdown the server. -2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote +2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable + +> [!IMPORTANT] +> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree. + + +Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL` \ No newline at end of file diff --git a/integration/api_test.go b/integration/api_test.go index d24f5001f..48572085d 100644 --- a/integration/api_test.go +++ b/integration/api_test.go @@ -22,13 +22,12 @@ func TestAPIGenerate(t *testing.T) { // Set up the test data req := api.GenerateRequest{ Model: smol, - Prompt: "why is the sky blue? be brief", + Prompt: blueSkyPrompt, Options: map[string]interface{}{ "temperature": 0, "seed": 123, }, } - anyResp := []string{"rayleigh", "scattering"} client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() @@ -120,14 +119,14 @@ func TestAPIGenerate(t *testing.T) { // Verify the response contains the expected data response := buf.String() atLeastOne := false - for _, resp := range anyResp { + for _, resp := range blueSkyExpected { if strings.Contains(strings.ToLower(response), resp) { atLeastOne = true break } } if !atLeastOne { - t.Errorf("none of %v found in %s", anyResp, response) + t.Errorf("none of %v found in %s", blueSkyExpected, response) } case <-ctx.Done(): t.Error("outer test context done while waiting for generate") @@ -181,7 +180,7 @@ func TestAPIChat(t *testing.T) { Messages: []api.Message{ { Role: "user", - Content: "why is the sky blue? be brief", + Content: blueSkyPrompt, }, }, Options: map[string]interface{}{ @@ -189,7 +188,6 @@ func TestAPIChat(t *testing.T) { "seed": 123, }, } - anyResp := []string{"rayleigh", "scattering"} client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() @@ -279,14 +277,14 @@ func TestAPIChat(t *testing.T) { // Verify the response contains the expected data response := buf.String() atLeastOne := false - for _, resp := range anyResp { + for _, resp := range blueSkyExpected { if strings.Contains(strings.ToLower(response), resp) { atLeastOne = true break } } if !atLeastOne { - t.Errorf("none of %v found in %s", anyResp, response) + t.Errorf("none of %v found in %s", blueSkyExpected, response) } case <-ctx.Done(): t.Error("outer test context done while waiting for chat") @@ -390,7 +388,7 @@ func TestAPIEmbeddings(t *testing.T) { client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() req := api.EmbeddingRequest{ - Model: "orca-mini", + Model: libraryEmbedModels[0], Prompt: "why is the sky blue?", Options: map[string]interface{}{ "temperature": 0, @@ -410,3 +408,99 @@ func TestAPIEmbeddings(t *testing.T) { t.Errorf("zero length embedding response") } } + +func TestAPIToolCalling(t *testing.T) { + initialTimeout := 60 * time.Second + streamTimeout := 30 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + modelName := "qwen3:0.6b" + if err := PullIfMissing(ctx, client, modelName); err != nil { + t.Fatalf("pull failed %s", err) + } + + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "The city and state, e.g. San Francisco, CA", + }, + }, + }, + }, + }, + } + + req := api.ChatRequest{ + Model: modelName, + Messages: []api.Message{ + { + Role: "user", + Content: "Call get_weather with location set to San Francisco.", + }, + }, + Tools: tools, + Options: map[string]any{ + "temperature": 0, + }, + } + + stallTimer := time.NewTimer(initialTimeout) + var gotToolCall bool + var lastToolCall api.ToolCall + + fn := func(response api.ChatResponse) error { + if len(response.Message.ToolCalls) > 0 { + gotToolCall = true + lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1] + } + if !stallTimer.Reset(streamTimeout) { + return fmt.Errorf("stall was detected while streaming response, aborting") + } + return nil + } + + stream := true + req.Stream = &stream + done := make(chan int) + var genErr error + go func() { + genErr = client.Chat(ctx, &req, fn) + done <- 0 + }() + + select { + case <-stallTimer.C: + t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String()) + case <-done: + if genErr != nil { + t.Fatalf("chat failed: %v", genErr) + } + + if !gotToolCall { + t.Fatalf("expected at least one tool call, got none") + } + + if lastToolCall.Function.Name != "get_weather" { + t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather") + } + + if _, ok := lastToolCall.Function.Arguments["location"]; !ok { + t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String()) + } + case <-ctx.Done(): + t.Error("outer test context done while waiting for tool-calling chat") + } +} diff --git a/integration/basic_test.go b/integration/basic_test.go index 13c2f22a2..0a6b9253d 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -11,7 +11,6 @@ import ( "time" "github.com/ollama/ollama/api" - "github.com/stretchr/testify/require" ) func TestBlueSky(t *testing.T) { @@ -20,14 +19,14 @@ func TestBlueSky(t *testing.T) { // Set up the test data req := api.GenerateRequest{ Model: smol, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, Stream: &stream, Options: map[string]any{ "temperature": 0, "seed": 123, }, } - GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) + GenerateTestHelper(ctx, t, req, blueSkyExpected) } func TestUnicode(t *testing.T) { @@ -37,8 +36,8 @@ func TestUnicode(t *testing.T) { // Set up the test data req := api.GenerateRequest{ // DeepSeek has a Unicode tokenizer regex, making it a unicode torture test - Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", - Prompt: "天空为什么是蓝色的?", + Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage? + Prompt: "天空为什么是蓝色的?", // Why is the sky blue? Stream: &stream, Options: map[string]any{ "temperature": 0, @@ -50,8 +49,20 @@ func TestUnicode(t *testing.T) { } client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() - require.NoError(t, PullIfMissing(ctx, client, req.Model)) - DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second) + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatal(err) + } + slog.Info("loading", "model", req.Model) + err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil }) + if err != nil { + t.Fatalf("failed to load model %s: %s", req.Model, err) + } + skipIfNotGPULoaded(ctx, t, client, req.Model, 100) + + DoGenerate(ctx, t, client, req, []string{ + "散射", // scattering + "频率", // frequency + }, 120*time.Second, 120*time.Second) } func TestExtendedUnicodeOutput(t *testing.T) { @@ -69,7 +80,9 @@ func TestExtendedUnicodeOutput(t *testing.T) { } client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() - require.NoError(t, PullIfMissing(ctx, client, req.Model)) + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatal(err) + } DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second) } @@ -84,7 +97,9 @@ func TestUnicodeModelDir(t *testing.T) { } modelDir, err := os.MkdirTemp("", "ollama_埃") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } defer os.RemoveAll(modelDir) slog.Info("unicode", "OLLAMA_MODELS", modelDir) @@ -95,12 +110,12 @@ func TestUnicodeModelDir(t *testing.T) { req := api.GenerateRequest{ Model: smol, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, Stream: &stream, Options: map[string]any{ "temperature": 0, "seed": 123, }, } - GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) + GenerateTestHelper(ctx, t, req, blueSkyExpected) } diff --git a/integration/concurrency_test.go b/integration/concurrency_test.go index 52a7f36bd..3104eacca 100644 --- a/integration/concurrency_test.go +++ b/integration/concurrency_test.go @@ -14,8 +14,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" @@ -79,21 +77,21 @@ func TestMultiModelStress(t *testing.T) { t.Fatal(err) } + // All models compatible with ollama-engine smallModels := []string{ "llama3.2:1b", "qwen3:0.6b", - "gemma:2b", - "deepseek-r1:1.5b", - "starcoder2:3b", + "gemma2:2b", + "deepseek-r1:1.5b", // qwen2 arch + "gemma3:270m", } mediumModels := []string{ - "qwen3:8b", - "llama2", - "deepseek-r1:7b", - "mistral", - "dolphin-mistral", - "gemma:7b", - "codellama:7b", + "llama3.2:3b", // ~3.4G + "qwen3:8b", // ~6.6G + "gpt-oss:20b", // ~15G + "deepseek-r1:7b", // ~5.6G + "gemma3:4b", // ~5.8G + "gemma2:9b", // ~8.1G } var chosenModels []string @@ -114,13 +112,16 @@ func TestMultiModelStress(t *testing.T) { // Make sure all the models are pulled before we get started for _, model := range chosenModels { - require.NoError(t, PullIfMissing(ctx, client, model)) + if err := PullIfMissing(ctx, client, model); err != nil { + t.Fatal(err) + } } // Determine how many models we can load in parallel before we exceed VRAM // The intent is to go 1 over what can fit so we force the scheduler to thrash targetLoadCount := 0 slog.Info("Loading models to find how many can fit in VRAM before overflowing") +chooseModels: for i, model := range chosenModels { req := &api.GenerateRequest{Model: model} slog.Info("loading", "model", model) @@ -142,6 +143,13 @@ func TestMultiModelStress(t *testing.T) { slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount]) break } + // Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts + for _, m := range models.Models { + if m.SizeVRAM == 0 { + slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount]) + break chooseModels + } + } } } if targetLoadCount == len(chosenModels) { diff --git a/integration/context_test.go b/integration/context_test.go index b28d11380..9d13f7acb 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -22,7 +22,7 @@ func TestLongInputContext(t *testing.T) { defer cancel() // Set up the test data req := api.GenerateRequest{ - Model: "llama2", + Model: smol, Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?", Stream: &stream, Options: map[string]any{ @@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("PullIfMissing failed: %v", err) } - DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second) + DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second) } func TestContextExhaustion(t *testing.T) { @@ -49,8 +49,8 @@ func TestContextExhaustion(t *testing.T) { defer cancel() // Set up the test data req := api.GenerateRequest{ - Model: "llama2", - Prompt: "Write me a story with a ton of emojis?", + Model: smol, + Prompt: "Write me a story in english with a lot of emojis", Stream: &stream, Options: map[string]any{ "temperature": 0, @@ -63,11 +63,11 @@ func TestContextExhaustion(t *testing.T) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("PullIfMissing failed: %v", err) } - DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second) + DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second) } -// Send multiple requests with prior context and ensure the response is coherant and expected -func TestGenerateWithHistory(t *testing.T) { +// Send multiple generate requests with prior context and ensure the response is coherant and expected +func TestParallelGenerateWithHistory(t *testing.T) { modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model req, resp := GenerateRequests() numParallel := 2 @@ -111,5 +111,148 @@ func TestGenerateWithHistory(t *testing.T) { }(i) } wg.Wait() - +} + +// Send generate requests with prior context and ensure the response is coherant and expected +func TestGenerateWithHistory(t *testing.T) { + req := api.GenerateRequest{ + Model: smol, + Prompt: rainbowPrompt, + Stream: &stream, + KeepAlive: &api.Duration{Duration: 10 * time.Second}, + Options: map[string]any{ + "num_ctx": 16384, + }, + } + + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Get the server running (if applicable) warm the model up with a single initial request + slog.Info("loading", "model", req.Model) + err := client.Generate(ctx, + &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options}, + func(response api.GenerateResponse) error { return nil }, + ) + if err != nil { + t.Fatalf("failed to load model %s: %s", req.Model, err) + } + + req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + + for i := 0; i < len(rainbowFollowups); i++ { + req.Prompt = rainbowFollowups[i] + if time.Now().Sub(started) > softTimeout { + slog.Info("exceeded soft timeout, winding down test") + return + } + req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + } +} + +// Send multiple chat requests with prior context and ensure the response is coherant and expected +func TestParallelChatWithHistory(t *testing.T) { + modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model + req, resp := ChatRequests() + numParallel := 2 + iterLimit := 2 + + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Get the server running (if applicable) warm the model up with a single initial empty request + slog.Info("loading", "model", modelOverride) + err := client.Generate(ctx, + &api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}}, + func(response api.GenerateResponse) error { return nil }, + ) + if err != nil { + t.Fatalf("failed to load model %s: %s", modelOverride, err) + } + + var wg sync.WaitGroup + wg.Add(numParallel) + for i := range numParallel { + go func(i int) { + defer wg.Done() + k := i % len(req) + req[k].Model = modelOverride + for j := 0; j < iterLimit; j++ { + if time.Now().Sub(started) > softTimeout { + slog.Info("exceeded soft timeout, winding down test") + return + } + slog.Info("Starting", "thread", i, "iter", j) + // On slower GPUs it can take a while to process the concurrent requests + // so we allow a much longer initial timeout + assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second) + if assistant == nil { + t.Fatalf("didn't get an assistant response for context") + } + req[k].Messages = append(req[k].Messages, + *assistant, + api.Message{Role: "user", Content: "tell me more!"}, + ) + } + }(i) + } + wg.Wait() +} + +// Send generate requests with prior context and ensure the response is coherant and expected +func TestChatWithHistory(t *testing.T) { + req := api.ChatRequest{ + Model: smol, + Stream: &stream, + KeepAlive: &api.Duration{Duration: 10 * time.Second}, + Options: map[string]any{ + "num_ctx": 16384, + }, + Messages: []api.Message{ + { + Role: "user", + Content: rainbowPrompt, + }, + }, + } + + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Get the server running (if applicable) warm the model up with a single initial request + slog.Info("loading", "model", req.Model) + err := client.Generate(ctx, + &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options}, + func(response api.GenerateResponse) error { return nil }, + ) + if err != nil { + t.Fatalf("failed to load model %s: %s", req.Model, err) + } + + assistant := DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + + for i := 0; i < len(rainbowFollowups); i++ { + if time.Now().Sub(started) > softTimeout { + slog.Info("exceeded soft timeout, winding down test") + return + } + req.Messages = append(req.Messages, + *assistant, + api.Message{Role: "user", Content: rainbowFollowups[i]}, + ) + + assistant = DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + if assistant == nil { + t.Fatalf("didn't get an assistant response for context") + } + } } diff --git a/integration/embed_test.go b/integration/embed_test.go index 09369dbb4..a68524486 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" ) @@ -38,14 +39,14 @@ func TestAllMiniLMEmbeddings(t *testing.T) { defer cleanup() req := api.EmbeddingRequest{ - Model: "all-minilm", - Prompt: "why is the sky blue?", + Model: "all-minilm", + Prompt: "why is the sky blue?", + KeepAlive: &api.Duration{Duration: 10 * time.Second}, } res, err := embeddingTestHelper(ctx, client, t, req) - if err != nil { - t.Fatalf("error: %v", err) + t.Fatal(err) } if len(res.Embedding) != 384 { @@ -73,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) { } res, err := embedTestHelper(ctx, client, t, req) - if err != nil { - t.Fatalf("error: %v", err) + t.Fatal(err) } if len(res.Embeddings) != 1 { @@ -111,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { } res, err := embedTestHelper(ctx, client, t, req) - if err != nil { - t.Fatalf("error: %v", err) + t.Fatal(err) } if len(res.Embeddings) != 2 { @@ -155,93 +154,135 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { truncTrue, truncFalse := true, false - type testReq struct { - Name string - Request api.EmbedRequest + want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{ + Model: "all-minilm", + Input: "why", + }) + if err != nil { + t.Fatal(err) } - reqs := []testReq{ + cases := []struct { + name string + request api.EmbedRequest + check func(*api.EmbedResponse, error) + }{ { - Name: "Target Truncation", - Request: api.EmbedRequest{ + name: "target truncation", + request: api.EmbedRequest{ Model: "all-minilm", Input: "why", }, - }, - { - Name: "Default Truncate", - Request: api.EmbedRequest{ - Model: "all-minilm", - Input: "why is the sky blue?", - Options: map[string]any{"num_ctx": 1}, + check: func(got *api.EmbedResponse, err error) { + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" { + t.Errorf("embedding mismatch (-want +got):\n%s", diff) + } }, }, { - Name: "Explicit Truncate", - Request: api.EmbedRequest{ + name: "default truncate", + request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Options: map[string]any{"num_ctx": 3}, + }, + check: func(got *api.EmbedResponse, err error) { + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" { + t.Errorf("embedding mismatch (-want +got):\n%s", diff) + } + }, + }, + { + name: "explicit truncate", + request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 3}, + }, + check: func(got *api.EmbedResponse, err error) { + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" { + t.Errorf("embedding mismatch (-want +got):\n%s", diff) + } + }, + }, + { + name: "truncate error", + request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 3}, + }, + check: func(res *api.EmbedResponse, err error) { + if err.Error() != "input exceeds maximum context length" { + t.Fatalf("expected truncation error, got: %v", err) + } + }, + }, + { + name: "input after truncate error", + request: api.EmbedRequest{ Model: "all-minilm", Input: "why is the sky blue?", Truncate: &truncTrue, Options: map[string]any{"num_ctx": 1}, }, + check: func(res *api.EmbedResponse, err error) { + if err.Error() != "input after truncation exceeds maximum context length" { + t.Fatalf("expected truncation error, got: %v", err) + } + }, + }, + { + name: "input after truncate error", + request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 0}, + }, + check: func(res *api.EmbedResponse, err error) { + if err.Error() != "input after truncation exceeds maximum context length" { + t.Fatalf("expected truncation error, got: %v", err) + } + }, }, } - res := make(map[string]*api.EmbedResponse) - - for _, req := range reqs { - response, err := embedTestHelper(ctx, client, t, req.Request) - if err != nil { - t.Fatalf("error: %v", err) - } - res[req.Name] = response - } - - if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] { - t.Fatal("expected default request to truncate correctly") - } - - if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] { - t.Fatal("expected default request and truncate true request to be the same") - } - - // check that truncate set to false returns an error if context length is exceeded - _, err := embedTestHelper(ctx, client, t, api.EmbedRequest{ - Model: "all-minilm", - Input: "why is the sky blue?", - Truncate: &truncFalse, - Options: map[string]any{"num_ctx": 1}, - }) - - if err == nil { - t.Fatal("expected error, got nil") + for _, req := range cases { + t.Run(req.name, func(t *testing.T) { + req.check(embedTestHelper(ctx, client, t, req.request)) + }) } } func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { + t.Helper() + if err := PullIfMissing(ctx, client, req.Model); err != nil { - t.Fatalf("failed to pull model %s: %v", req.Model, err) + t.Fatal(err) } - response, err := client.Embeddings(ctx, &req) - - if err != nil { - return nil, err - } - - return response, nil + return client.Embeddings(ctx, &req) } func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { + t.Helper() + if err := PullIfMissing(ctx, client, req.Model); err != nil { - t.Fatalf("failed to pull model %s: %v", req.Model, err) + t.Fatal(err) } - response, err := client.Embed(ctx, &req) - - if err != nil { - return nil, err - } - - return response, nil + return client.Embed(ctx, &req) } diff --git a/integration/library_models_test.go b/integration/library_models_test.go index cdf65efc8..49e1097b8 100644 --- a/integration/library_models_test.go +++ b/integration/library_models_test.go @@ -4,7 +4,9 @@ package integration import ( "context" + "fmt" "log/slog" + "os" "testing" "time" @@ -20,6 +22,7 @@ func TestLibraryModelsGenerate(t *testing.T) { defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() + targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE") chatModels := libraryChatModels for _, model := range chatModels { @@ -30,16 +33,26 @@ func TestLibraryModelsGenerate(t *testing.T) { if err := PullIfMissing(ctx, client, model); err != nil { t.Fatalf("pull failed %s", err) } + if targetArch != "" { + resp, err := client.Show(ctx, &api.ShowRequest{Name: model}) + if err != nil { + t.Fatalf("unable to show model: %s", err) + } + arch := resp.ModelInfo["general.architecture"].(string) + if arch != targetArch { + t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch)) + } + } req := api.GenerateRequest{ Model: model, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: map[string]interface{}{ "temperature": 0.1, "seed": 123, }, } - anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"} + anyResp := blueSkyExpected // Special cases if model == "duckdb-nsql" { anyResp = []string{"select", "from"} diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go index bbd031a93..9bf11257c 100644 --- a/integration/llm_image_test.go +++ b/integration/llm_image_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/ollama/ollama/api" - "github.com/stretchr/testify/require" ) func TestVisionModels(t *testing.T) { @@ -32,7 +31,9 @@ func TestVisionModels(t *testing.T) { for _, v := range testCases { t.Run(v.model, func(t *testing.T) { image, err := base64.StdEncoding.DecodeString(imageEncoding) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } req := api.GenerateRequest{ Model: v.model, Prompt: "what does the text in this image say?", @@ -52,7 +53,9 @@ func TestVisionModels(t *testing.T) { // Note: sometimes it returns "the ollamas" sometimes "the ollams" resp := "the ollam" defer cleanup() - require.NoError(t, PullIfMissing(ctx, client, req.Model)) + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatal(err) + } // llava models on CPU can be quite slow to start DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second) }) @@ -62,7 +65,9 @@ func TestVisionModels(t *testing.T) { func TestIntegrationSplitBatch(t *testing.T) { skipUnderMinVRAM(t, 6) image, err := base64.StdEncoding.DecodeString(imageEncoding) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } req := api.GenerateRequest{ Model: "gemma3:4b", // Fill up a chunk of the batch so the image will partially spill over into the next one @@ -84,7 +89,9 @@ func TestIntegrationSplitBatch(t *testing.T) { defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() - require.NoError(t, PullIfMissing(ctx, client, req.Model)) + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatal(err) + } // llava models on CPU can be quite slow to start, DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second) } diff --git a/integration/llm_test.go b/integration/llm_test.go deleted file mode 100644 index 50249bf0f..000000000 --- a/integration/llm_test.go +++ /dev/null @@ -1,47 +0,0 @@ -//go:build integration - -package integration - -import ( - "context" - "testing" - "time" - - "github.com/ollama/ollama/api" -) - -// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server -// package to avoid circular dependencies - -var ( - stream = false - req = [2]api.GenerateRequest{ - { - Model: smol, - Prompt: "why is the ocean blue?", - Stream: &stream, - Options: map[string]any{ - "seed": 42, - "temperature": 0.0, - }, - }, { - Model: smol, - Prompt: "what is the origin of the us thanksgiving holiday?", - Stream: &stream, - Options: map[string]any{ - "seed": 42, - "temperature": 0.0, - }, - }, - } - resp = [2][]string{ - {"sunlight", "scattering", "interact"}, - {"england", "english", "massachusetts", "pilgrims"}, - } -) - -func TestIntegrationSimple(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) - defer cancel() - GenerateTestHelper(ctx, t, req[0], resp[0]) -} diff --git a/integration/max_queue_test.go b/integration/max_queue_test.go index 7bb9336a0..24e3101f2 100644 --- a/integration/max_queue_test.go +++ b/integration/max_queue_test.go @@ -13,12 +13,12 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/ollama/ollama/api" ) func TestMaxQueue(t *testing.T) { + t.Skip("this test needs to be re-evaluated to use a proper embedding model") + if os.Getenv("OLLAMA_TEST_EXISTING") != "" { t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size") return @@ -45,7 +45,9 @@ func TestMaxQueue(t *testing.T) { client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() - require.NoError(t, PullIfMissing(ctx, client, req.Model)) + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatal(err) + } // Context for the worker threads so we can shut them down // embedCtx, embedCancel := context.WithCancel(ctx) @@ -89,7 +91,9 @@ func TestMaxQueue(t *testing.T) { switch { case genErr == nil: successCount++ - require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable + if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable + t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding)) + } case errors.Is(genErr, context.Canceled): canceledCount++ case strings.Contains(genErr.Error(), "busy"): @@ -97,7 +101,9 @@ func TestMaxQueue(t *testing.T) { case strings.Contains(genErr.Error(), "connection reset by peer"): resetByPeerCount++ default: - require.NoError(t, genErr, "%d request failed", i) + if genErr != nil { + t.Fatalf("%d request failed", i) + } } slog.Info("embed finished", "id", i) @@ -108,8 +114,13 @@ func TestMaxQueue(t *testing.T) { embedwg.Wait() slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount) - require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?") - require.True(t, busyCount > 0, "no requests hit busy error but some should have") - require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout") - + if resetByPeerCount != 0 { + t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount) + } + if busyCount == 0 { + t.Fatalf("no requests hit busy error but some should have") + } + if canceledCount > 0 { + t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount) + } } diff --git a/integration/model_arch_test.go b/integration/model_arch_test.go index 9fc2e01dd..721d95c54 100644 --- a/integration/model_arch_test.go +++ b/integration/model_arch_test.go @@ -68,14 +68,13 @@ func TestModelsGenerate(t *testing.T) { // TODO - fiddle with context size req := api.GenerateRequest{ Model: model, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, Options: map[string]interface{}{ "temperature": 0, "seed": 123, }, } - anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"} - DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second) + DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second) }) } } diff --git a/integration/model_perf_test.go b/integration/model_perf_test.go index 759e8b9a2..3d6ba9239 100644 --- a/integration/model_perf_test.go +++ b/integration/model_perf_test.go @@ -40,6 +40,18 @@ var ( // cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv // cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv func TestModelsPerf(t *testing.T) { + if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" { + doModelPerfTest(t, ollamaEngineChatModels) + } else { + doModelPerfTest(t, append(ollamaEngineChatModels, llamaRunnerChatModels...)) + } +} + +func TestLibraryModelsPerf(t *testing.T) { + doModelPerfTest(t, libraryChatModels) +} + +func doModelPerfTest(t *testing.T, chatModels []string) { softTimeout, hardTimeout := getTimeouts(t) slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout) ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) @@ -65,14 +77,12 @@ func TestModelsPerf(t *testing.T) { } longPrompt := "summarize the following: " + string(data) - var chatModels []string - if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" { - chatModels = ollamaEngineChatModels - } else { - chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...) - } + targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE") for _, model := range chatModels { + if !strings.Contains(model, ":") { + model = model + ":latest" + } t.Run(model, func(t *testing.T) { if time.Now().Sub(started) > softTimeout { t.Skip("skipping remaining tests to avoid excessive runtime") @@ -88,6 +98,9 @@ func TestModelsPerf(t *testing.T) { } arch := resp.ModelInfo["general.architecture"].(string) maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64)) + if targetArch != "" && arch != targetArch { + t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch)) + } if maxVram > 0 { resp, err := client.List(ctx) @@ -151,8 +164,8 @@ func TestModelsPerf(t *testing.T) { prompt string anyResp []string }{ - {"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}}, - {maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}}, + {blueSkyPrompt, blueSkyExpected}, + {maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}}, } var gpuPercent int for _, tc := range testCases { @@ -241,11 +254,12 @@ func TestModelsPerf(t *testing.T) { } } } + // Round the logged prompt count for comparisons across versions/configurations which can vary slightly fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n", "MODEL", "CONTEXT", "GPU PERCENT", - "PROMPT COUNT", + "APPROX PROMPT COUNT", "LOAD TIME", "PROMPT EVAL TPS", "EVAL TPS", @@ -254,7 +268,7 @@ func TestModelsPerf(t *testing.T) { model, numCtx, gpuPercent, - resp.PromptEvalCount, + (resp.PromptEvalCount/10)*10, float64(resp.LoadDuration)/1000000000.0, float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0), float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0), diff --git a/integration/quantization_test.go b/integration/quantization_test.go index af9da0b62..305647496 100644 --- a/integration/quantization_test.go +++ b/integration/quantization_test.go @@ -76,7 +76,7 @@ func TestQuantization(t *testing.T) { stream := true genReq := api.GenerateRequest{ Model: newName, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, KeepAlive: &api.Duration{Duration: 3 * time.Second}, Options: map[string]any{ "seed": 42, @@ -88,14 +88,13 @@ func TestQuantization(t *testing.T) { // Some smaller quantizations can cause models to have poor quality // or get stuck in repetition loops, so we stop as soon as we have any matches - anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"} reqCtx, reqCancel := context.WithCancel(ctx) atLeastOne := false var buf bytes.Buffer genfn := func(response api.GenerateResponse) error { buf.Write([]byte(response.Response)) fullResp := strings.ToLower(buf.String()) - for _, resp := range anyResp { + for _, resp := range blueSkyExpected { if strings.Contains(fullResp, resp) { atLeastOne = true t.Log(fullResp) diff --git a/integration/utils_test.go b/integration/utils_test.go index d7e3790b1..554b02709 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "log/slog" + "math" "math/rand" "net" "net/http" @@ -25,11 +26,11 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/app/lifecycle" "github.com/ollama/ollama/format" - "github.com/stretchr/testify/require" ) var ( - smol = "llama3.2:1b" + smol = "llama3.2:1b" + stream = false ) var ( @@ -255,13 +256,28 @@ var ( "snowflake-arctic-embed", "snowflake-arctic-embed2", } + + blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply" + blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"} + + rainbowPrompt = "how do rainbows form? Be brief but factual in your reply" + rainbowFollowups = []string{ + "Explain the physics involved in them. Be breif in your reply", + "Explain the chemistry involved in them. Be breif in your reply", + "What are common myths related to them? Be brief in your reply", + "What are common fairytales related to them? Be brief in your reply", + "Can they form if there is no rain? Be breif in your reply", + "Can they form if there are no clouds? Be breif in your reply", + "Do they happen on other planets? Be brief in your reply", + } + rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "end", "gold", "fortune", "blessing", "prosperity", "magic", "shower", "sky", "shimmer", "light", "storm", "sunny"} ) func init() { lifecycle.InitLogging() - custom := os.Getenv("OLLAMA_TEST_SMOL_MODEL") + custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL") if custom != "" { - slog.Info("setting smol test model to " + custom) + slog.Info("setting default test model to " + custom) smol = custom } } @@ -435,7 +451,27 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin } lifecycle.ServerLogFile = fp.Name() fp.Close() - require.NoError(t, startServer(t, ctx, testEndpoint)) + if err := startServer(t, ctx, testEndpoint); err != nil { + t.Fatal(err) + } + } + // Make sure server is online and healthy before returning + listCtx, cancel := context.WithDeadlineCause( + ctx, + time.Now().Add(120*time.Second), + fmt.Errorf("list models took too long"), + ) + defer cancel() + models, err := client.ListRunning(listCtx) + if err != nil { + t.Fatal(err) + } + if len(models.Models) > 0 { + names := make([]string, len(models.Models)) + for i, m := range models.Models { + names[i] = m.Name + } + slog.Info("currently loaded", "models", names) } return client, testEndpoint, func() { @@ -468,7 +504,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) { client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() - require.NoError(t, PullIfMissing(ctx, client, genReq.Model)) + if err := PullIfMissing(ctx, client, genReq.Model); err != nil { + t.Fatal(err) + } DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second) } @@ -497,6 +535,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap done <- 0 }() + var response string + verify := func() { + // Verify the response contains the expected data + response = buf.String() + atLeastOne := false + for _, resp := range anyResp { + if strings.Contains(strings.ToLower(response), resp) { + atLeastOne = true + break + } + } + if !atLeastOne { + t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response) + } + } + select { case <-stallTimer.C: if buf.Len() == 0 { @@ -509,20 +563,17 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr) return context } - require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt) - // Verify the response contains the expected data - response := buf.String() - atLeastOne := false - for _, resp := range anyResp { - if strings.Contains(strings.ToLower(response), resp) { - atLeastOne = true - break - } + if genErr != nil { + t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt) } - require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response) + verify() slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response) case <-ctx.Done(): - t.Error("outer test context done while waiting for generate") + // On slow systems, we might timeout before some models finish rambling, so check what we have so far to see + // if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass + // if they are still generating valid responses + slog.Warn("outer test context done while waiting for generate") + verify() } return context } @@ -543,7 +594,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { KeepAlive: &api.Duration{Duration: 10 * time.Second}, }, { Model: smol, - Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply", + Prompt: rainbowPrompt, Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, }, { @@ -559,19 +610,106 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { }, }, [][]string{ - {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"}, - {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"}, - {"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states"}, + {"sunlight", "scatter", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorb", "wavelength", "water", "molecule"}, + {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigment", "particle", "iron oxide", "rust", "air", "water", "wet", "mixture", "mixing", "mineral", "element", "decomposed", "matter", "wavelength"}, + rainbowExpected, {"fourth", "july", "declaration", "independence"}, - {"nitrogen", "oxygen", "carbon", "dioxide"}, + {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor", "fluid", "particles", "gas"}, } } +func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message { + stallTimer := time.NewTimer(initialTimeout) + var buf bytes.Buffer + role := "assistant" + fn := func(response api.ChatResponse) error { + // fmt.Print(".") + role = response.Message.Role + buf.Write([]byte(response.Message.Content)) + if !stallTimer.Reset(streamTimeout) { + return errors.New("stall was detected while streaming response, aborting") + } + return nil + } + + stream := true + req.Stream = &stream + done := make(chan int) + var genErr error + go func() { + genErr = client.Chat(ctx, &req, fn) + done <- 0 + }() + + var response string + verify := func() { + // Verify the response contains the expected data + response = buf.String() + atLeastOne := false + for _, resp := range anyResp { + if strings.Contains(strings.ToLower(response), resp) { + atLeastOne = true + break + } + } + if !atLeastOne { + t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages) + } + } + + select { + case <-stallTimer.C: + if buf.Len() == 0 { + t.Errorf("generate never started. Timed out after :%s", initialTimeout.String()) + } else { + t.Errorf("generate stalled. Response so far:%s", buf.String()) + } + case <-done: + if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") { + slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr) + return nil + } + if genErr != nil { + t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages) + } + verify() + slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response) + case <-ctx.Done(): + // On slow systems, we might timeout before some models finish rambling, so check what we have so far to see + // if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass + // if they are still generating valid responses + slog.Warn("outer test context done while waiting for chat") + verify() + } + return &api.Message{Role: role, Content: buf.String()} +} + +func ChatRequests() ([]api.ChatRequest, [][]string) { + genReqs, results := GenerateRequests() + reqs := make([]api.ChatRequest, len(genReqs)) + // think := api.ThinkValue{Value: "low"} + for i := range reqs { + reqs[i].Model = genReqs[i].Model + reqs[i].Stream = genReqs[i].Stream + reqs[i].KeepAlive = genReqs[i].KeepAlive + // reqs[i].Think = &think + reqs[i].Messages = []api.Message{ + { + Role: "user", + Content: genReqs[i].Prompt, + }, + } + } + return reqs, results +} + func skipUnderMinVRAM(t *testing.T, gb uint64) { // TODO use info API in the future if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" { maxVram, err := strconv.ParseUint(s, 10, 64) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } // Don't hammer on small VRAM cards... if maxVram < gb*format.GibiByte { t.Skip("skipping with small VRAM to avoid timeouts") @@ -579,6 +717,39 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) { } } +// Skip if the target model isn't X% GPU loaded to avoid excessive runtime +func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) { + models, err := client.ListRunning(ctx) + if err != nil { + t.Fatalf("failed to list running models: %s", err) + } + loaded := []string{} + for _, m := range models.Models { + loaded = append(loaded, m.Name) + if m.Name != model { + continue + } + gpuPercent := 0 + switch { + case m.SizeVRAM == 0: + gpuPercent = 0 + case m.SizeVRAM == m.Size: + gpuPercent = 100 + case m.SizeVRAM > m.Size || m.Size == 0: + t.Logf("unexpected size detected: %d", m.SizeVRAM) + default: + sizeCPU := m.Size - m.SizeVRAM + cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110) + gpuPercent = int(100 - cpuPercent) + } + if gpuPercent < minPercent { + t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent)) + } + return + } + t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded)) +} + func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) { deadline, hasDeadline := t.Deadline() if !hasDeadline { diff --git a/llama/llama.go b/llama/llama.go index ac2c112c2..90b462703 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -42,6 +42,7 @@ import ( _ "github.com/ollama/ollama/llama/llama.cpp/common" _ "github.com/ollama/ollama/llama/llama.cpp/src" _ "github.com/ollama/ollama/llama/llama.cpp/tools/mtmd" + "github.com/ollama/ollama/ml" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" ) @@ -62,8 +63,8 @@ func BackendInit() { C.llama_backend_init() } -func EnumerateGPUs() []string { - var ids []string +func EnumerateGPUs() []ml.DeviceID { + var ids []ml.DeviceID for i := range C.ggml_backend_dev_count() { device := C.ggml_backend_dev_get(i) @@ -71,7 +72,10 @@ func EnumerateGPUs() []string { if C.ggml_backend_dev_type(device) == C.GGML_BACKEND_DEVICE_TYPE_GPU { var props C.struct_ggml_backend_dev_props C.ggml_backend_dev_get_props(device, &props) - ids = append(ids, C.GoString(props.id)) + ids = append(ids, ml.DeviceID{ + ID: C.GoString(props.id), + Library: C.GoString(props.library), + }) } } @@ -515,33 +519,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, } nChunks := C.mtmd_input_chunks_size(ic) numEmbed := llamaContext.Model().NEmbd() - lastChunkSize := 0 + embed := make([][]float32, 0) for i := range int(nChunks) { chunk := C.mtmd_input_chunks_get(ic, C.size_t(i)) numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk)) - lastChunkSize = numTokens + slog.Debug("chunk tokens", "index", i, "numTokens", numTokens) // Encode the chunk if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { return nil, errors.New("unable to encode mtmd image chunk") } - } - // Get the embeddings - embed := make([][]float32, lastChunkSize) - embd := C.mtmd_get_output_embd(c.c) - if nil == embd { - return nil, errors.New("failed to get image embedding") - } + // Get the embeddings for this chunk + chunkEmbed := make([][]float32, numTokens) + chunkEmbd := C.mtmd_get_output_embd(c.c) + if nil == chunkEmbd { + continue + } - // Extend the embedding array for each token - s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize) - rows := make([]float32, len(s)) - copy(rows, s) - for i := range lastChunkSize { - embed[i] = rows[i*numEmbed : (i+1)*numEmbed] + // Extend the embedding array for each token + s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed) + rows := make([]float32, len(s)) + copy(rows, s) + for i := range numTokens { + chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed] + } + embed = append(embed, chunkEmbed...) } - + slog.Debug("image embeddings", "totalEmbeddings", len(embed)) return embed, nil } diff --git a/llama/patches/0014-graph-memory-reporting-on-failure.patch b/llama/patches/0014-graph-memory-reporting-on-failure.patch index 26fe8a8e0..a9fc420f1 100644 --- a/llama/patches/0014-graph-memory-reporting-on-failure.patch +++ b/llama/patches/0014-graph-memory-reporting-on-failure.patch @@ -4,48 +4,38 @@ Date: Fri, 18 Apr 2025 15:58:19 -0700 Subject: [PATCH] graph memory reporting on failure --- - ggml/include/ggml-alloc.h | 6 ++++++ - ggml/include/ggml-backend.h | 6 ++++++ - ggml/src/ggml-alloc.c | 38 +++++++++++++++++++++++++++++++++---- - ggml/src/ggml-backend.cpp | 10 ++++++++++ - 4 files changed, 56 insertions(+), 4 deletions(-) + ggml/include/ggml-alloc.h | 1 + + ggml/include/ggml-backend.h | 1 + + ggml/src/ggml-alloc.c | 36 ++++++++++++++++++++++++++++++++---- + ggml/src/ggml-backend.cpp | 7 +++++++ + 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h -index 2cb150fd..781b1e10 100644 +index 2cb150fd2..7ab3f0192 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h -@@ -66,6 +66,12 @@ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph +@@ -65,6 +65,7 @@ GGML_API bool ggml_gallocr_reserve_n( + GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph); GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); ++GGML_API size_t ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); -+struct ggml_allocr_buffer_status { -+ size_t size; -+ bool allocated; -+}; -+GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); -+ // Utils // Create a buffer and allocate all the tensors in a ggml_context - GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index a2977ea2..8a91b381 100644 +index a2977ea2e..e8cf30841 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -304,6 +304,12 @@ extern "C" { +@@ -303,6 +303,7 @@ extern "C" { + GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched); GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); ++ GGML_API size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); -+ struct ggml_backend_buffer_status { -+ size_t size; -+ bool allocated; -+ }; -+ GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); -+ GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); - diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c -index 8b6e6028..41c8c4a2 100644 +index 8b6e60283..b58bd671d 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -350,6 +350,7 @@ struct node_alloc { @@ -108,11 +98,11 @@ index 8b6e6028..41c8c4a2 100644 } bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { -@@ -920,6 +932,24 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { +@@ -920,6 +932,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); } -+struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { ++size_t ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { + GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers); + + for (int i = 0; i < buffer_id; i++) { @@ -121,34 +111,29 @@ index 8b6e6028..41c8c4a2 100644 + // (See above.) However, we need a different check because multiple buffers might be NULL in our + // case and we still want to know the attempted size. + -+ struct ggml_allocr_buffer_status status = {0, true}; -+ return status; ++ return 0; + } + } + -+ struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL}; -+ return status; ++ return galloc->buffer_sizes[buffer_id]; +} + // utils static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 97f47abd..eded0291 100644 +index 97f47abd2..d02a40e60 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -1631,6 +1631,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe +@@ -1631,6 +1631,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); } -+struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { ++size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + -+ struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); -+ struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated}; -+ -+ return status; ++ return ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); +} + void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { diff --git a/llama/patches/0022-ggml-No-alloc-mode.patch b/llama/patches/0022-ggml-No-alloc-mode.patch index fa738452f..59df80fdb 100644 --- a/llama/patches/0022-ggml-No-alloc-mode.patch +++ b/llama/patches/0022-ggml-No-alloc-mode.patch @@ -3,35 +3,45 @@ From: Jesse Gross Date: Wed, 23 Jul 2025 11:58:49 -0700 Subject: [PATCH] ggml: No-alloc mode -Callers can set a backend buffer type to be no-alloc, meaning that +Callers can set a scheduler to be no-alloc, meaning that it does not allocate memory for tensors or operations. This can be used for calculating memory requirements. Tensors and graphs must be recreated with no-alloc set to false before loading data. - -Defaults to false for newly created backend buffer types. --- - ggml/include/ggml-backend.h | 1 + - ggml/src/ggml-backend-impl.h | 2 ++ - ggml/src/ggml-backend.cpp | 19 ++++++++++++++++++- - 3 files changed, 21 insertions(+), 1 deletion(-) + ggml/include/ggml-backend.h | 1 + + ggml/src/ggml-backend-impl.h | 16 +++ + ggml/src/ggml-backend.cpp | 72 ++++++++++- + ggml/src/ggml-cuda/common.cuh | 48 ++++++- + ggml/src/ggml-cuda/ggml-cuda.cu | 217 ++++++++++++++++++++++++++------ + 5 files changed, 310 insertions(+), 44 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 9424394e..b602a7c7 100644 +index 2773cc310..ae94887dd 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -35,6 +35,7 @@ extern "C" { - // +@@ -291,6 +291,7 @@ extern "C" { - GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft); -+ GGML_API void ggml_backend_buft_set_alloc (ggml_backend_buffer_type_t buft, bool alloc); - GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size); - GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); - GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); + // Initialize a backend scheduler, backends with low index are given priority over backends with high index + GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); ++ GGML_API ggml_backend_sched_t ggml_backend_sched_new_ext(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload, bool alloc_buffers); + GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); + + // Initialize backend buffers from a measure graph diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h -index c36c12d6..81749a5a 100644 +index c36c12d65..369e9e25a 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h -@@ -32,6 +32,7 @@ extern "C" { +@@ -26,12 +26,17 @@ extern "C" { + size_t (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); + // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false) + bool (*is_host) (ggml_backend_buffer_type_t buft); ++ ++ // (optional) returns a dummy buffer that is equivalent to one created by alloc_buffer but without actually being backed ++ // by memory ++ ggml_backend_buffer_t (*noalloc_buffer)(ggml_backend_buffer_type_t buft, size_t size); + }; + + struct ggml_backend_buffer_type { struct ggml_backend_buffer_type_i iface; ggml_backend_dev_t device; void * context; @@ -39,7 +49,7 @@ index c36c12d6..81749a5a 100644 }; // -@@ -63,6 +64,7 @@ extern "C" { +@@ -63,6 +68,7 @@ extern "C" { void * context; size_t size; enum ggml_backend_buffer_usage usage; @@ -47,26 +57,40 @@ index c36c12d6..81749a5a 100644 }; GGML_API ggml_backend_buffer_t ggml_backend_buffer_init( +@@ -114,6 +120,16 @@ extern "C" { + void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); + // wait for an event on on a different stream + void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); ++ ++ // (optional) reserves intermediate buffers needed for the compution ++ // if alloc is true, memory is actually allocated, otherwise the required amount is just returned by buffer_size ++ enum ggml_status (*graph_reserve) (ggml_backend_t backend, struct ggml_cgraph * cgraph, bool alloc); ++ ++ // (optional) returns the memory needed after calling graph_reserve ++ size_t (*buffer_size) (ggml_backend_t backend); ++ ++ // (optional) frees memory from intermediate buffers that was allocated either by graph_compute or graph_reserve ++ void (*reset) (ggml_backend_t backend); + }; + + struct ggml_backend { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index eded0291..05a842ed 100644 +index d02a40e60..6b4dee4c7 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -35,12 +35,22 @@ const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { - return buft->iface.get_name(buft); - } - -+void ggml_backend_buft_set_alloc(ggml_backend_buffer_type_t buft, bool alloc) { -+ buft->no_alloc = !alloc; -+} -+ - ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - if (size == 0) { - // return a dummy buffer for zero-sized allocations +@@ -41,6 +41,19 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t return ggml_backend_buffer_init(buft, {}, NULL, 0); } + if (buft->no_alloc) { -+ ggml_backend_buffer_t buf = ggml_backend_buffer_init(buft, {}, NULL, size); ++ ggml_backend_buffer_t buf; ++ ++ if (buft->iface.noalloc_buffer != NULL) { ++ buf = buft->iface.noalloc_buffer(buft, size); ++ } else { ++ buf = ggml_backend_buffer_init(buft, {}, NULL, size); ++ } ++ + buf->no_alloc = true; + return buf; + } @@ -74,7 +98,7 @@ index eded0291..05a842ed 100644 return buft->iface.alloc_buffer(buft, size); } -@@ -89,7 +99,8 @@ ggml_backend_buffer_t ggml_backend_buffer_init( +@@ -89,7 +102,8 @@ ggml_backend_buffer_t ggml_backend_buffer_init( /* .buft = */ buft, /* .context = */ context, /* .size = */ size, @@ -84,7 +108,7 @@ index eded0291..05a842ed 100644 }; return buffer; -@@ -119,6 +130,12 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { +@@ -119,6 +133,12 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { return NULL; } @@ -97,3 +121,532 @@ index eded0291..05a842ed 100644 void * base = buffer->iface.get_base(buffer); GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL"); +@@ -663,6 +683,12 @@ struct ggml_backend_sched { + bool op_offload; + + int debug; ++ ++ // allocate buffers on attached ggml_backend_buffer_type_t's and during reservation ++ // if false, dummy buffers are used for faster memory sizing calculations ++ // the scheduler needs to be recreated with allocated buffers before it can be used ++ // for computation ++ bool alloc_buffers; + }; + + #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) +@@ -1449,6 +1475,17 @@ ggml_backend_sched_t ggml_backend_sched_new( + size_t graph_size, + bool parallel, + bool op_offload) { ++ return ggml_backend_sched_new_ext(backends, bufts, n_backends, graph_size, parallel, op_offload, true); ++ } ++ ++ggml_backend_sched_t ggml_backend_sched_new_ext( ++ ggml_backend_t * backends, ++ ggml_backend_buffer_type_t * bufts, ++ int n_backends, ++ size_t graph_size, ++ bool parallel, ++ bool op_offload, ++ bool alloc_buffers) { + GGML_ASSERT(n_backends > 0); + GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); + GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); +@@ -1490,10 +1527,13 @@ ggml_backend_sched_t ggml_backend_sched_new( + sched->events[b][c] = ggml_backend_event_new(backends[b]->device); + } + } ++ ++ sched->bufts[b]->no_alloc = !alloc_buffers; + } + + sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); + sched->op_offload = op_offload; ++ sched->alloc_buffers = alloc_buffers; + + ggml_backend_sched_reset(sched); + +@@ -1508,6 +1548,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { + for (int c = 0; c < sched->n_copies; c++) { + ggml_backend_event_free(sched->events[b][c]); + } ++ ++ if (sched->backends[b]->iface.reset != NULL) { ++ sched->backends[b]->iface.reset(sched->backends[b]); ++ } + } + ggml_gallocr_free(sched->galloc); + ggml_free(sched->ctx); +@@ -1547,6 +1591,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * + return false; + } + ++ if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { ++ return false; ++ } ++ ++ struct ggml_backend_sched_split * splits = sched->splits; ++ for (int i = 0; i < sched->n_splits; i++) { ++ struct ggml_backend_sched_split * split = &splits[i]; ++ int split_backend_id = split->backend_id; ++ ggml_backend_t split_backend = sched->backends[split_backend_id]; ++ ++ if (split_backend->iface.graph_reserve != NULL) { ++ enum ggml_status ec = split_backend->iface.graph_reserve(split_backend, &split->graph, sched->alloc_buffers); ++ if (ec != GGML_STATUS_SUCCESS) { ++ return false; ++ } ++ } ++ } ++ + ggml_backend_sched_reset(sched); + + return true; +@@ -1635,7 +1697,13 @@ size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + +- return ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); ++ size_t size = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); ++ ++ if (backend->iface.buffer_size != NULL) { ++ size += backend->iface.buffer_size(backend); ++ } ++ ++ return size; + } + + void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { +diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh +index 2e5d48797..b915ee1b8 100644 +--- a/ggml/src/ggml-cuda/common.cuh ++++ b/ggml/src/ggml-cuda/common.cuh +@@ -35,6 +35,31 @@ + #include "vendors/cuda.h" + #endif // defined(GGML_USE_HIP) + ++extern bool reserving_graph; ++ ++// If we are reserving the graph, pointers might be invalid and will fail if cudaMemcpyAsync tries to validate them. ++// However, since we don't actually expect a result, we don't need to actually do the memcpy. ++static cudaError_t cudaMemcpyAsyncReserve ( void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream = 0 ) { ++ if (!reserving_graph) { ++ return cudaMemcpyAsync(dst, src, count, kind, stream); ++ } else { ++ return cudaSuccess; ++ } ++} ++ ++static cudaError_t cudaMemcpy2DAsyncReserve ( void* dst, size_t dpitch, const void* src, size_t spitch, size_t width, size_t height, cudaMemcpyKind kind, cudaStream_t stream = 0 ) { ++ if (!reserving_graph) { ++ return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, kind, stream); ++ } else { ++ return cudaSuccess; ++ } ++} ++ ++#undef cudaMemcpyAsync ++#define cudaMemcpyAsync cudaMemcpyAsyncReserve ++#undef cudaMemcpy2DAsync ++#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve ++ + #define STRINGIZE_IMPL(...) #__VA_ARGS__ + #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) + +@@ -771,6 +796,9 @@ struct ggml_cuda_pool { + + virtual void * alloc(size_t size, size_t * actual_size) = 0; + virtual void free(void * ptr, size_t size) = 0; ++ ++ virtual bool alloc_memory() = 0; ++ virtual size_t alloc_size() = 0; + }; + + template +@@ -914,11 +942,11 @@ struct ggml_backend_cuda_context { + // pool + std::unique_ptr pools[GGML_CUDA_MAX_DEVICES]; + +- static std::unique_ptr new_pool_for_device(int device); ++ static std::unique_ptr new_pool_for_device(int device, bool alloc); + + ggml_cuda_pool & pool(int device) { + if (pools[device] == nullptr) { +- pools[device] = new_pool_for_device(device); ++ pools[device] = new_pool_for_device(device, true); + } + return *pools[device]; + } +@@ -926,4 +954,20 @@ struct ggml_backend_cuda_context { + ggml_cuda_pool & pool() { + return pool(device); + } ++ ++ void pool_set_alloc(bool alloc) { ++ GGML_ASSERT(pools[device] == nullptr || pools[device]->alloc_memory() == alloc); ++ ++ if (pools[device] == nullptr) { ++ pools[device] = new_pool_for_device(device, alloc); ++ } ++ } ++ ++ size_t pool_get_alloc_size() { ++ if (pools[device] == nullptr) { ++ return 0; ++ } ++ ++ return pools[device]->alloc_size(); ++ } + }; +diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu +index c7f9dc3a5..d5abe09e0 100644 +--- a/ggml/src/ggml-cuda/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda/ggml-cuda.cu +@@ -350,6 +350,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { + + // #define DEBUG_CUDA_MALLOC + ++#define CUDA_ALIGNMENT 128 ++ + // buffer pool for cuda (legacy) + struct ggml_cuda_pool_leg : public ggml_cuda_pool { + static const int MAX_BUFFERS = 256; +@@ -362,9 +364,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { + + ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {}; + size_t pool_size = 0; ++ bool allocate = true; ++ size_t last_alloc = 0; + +- explicit ggml_cuda_pool_leg(int device) : +- device(device) { ++ explicit ggml_cuda_pool_leg(int device, bool alloc) : ++ device(device), ++ allocate(alloc) { + } + + ~ggml_cuda_pool_leg() { +@@ -372,7 +377,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { + for (int i = 0; i < MAX_BUFFERS; ++i) { + ggml_cuda_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { +- CUDA_CHECK(cudaFree(b.ptr)); ++ if (allocate) { ++ CUDA_CHECK(cudaFree(b.ptr)); ++ } + pool_size -= b.size; + } + } +@@ -420,8 +427,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { + void * ptr; + size_t look_ahead_size = (size_t) (1.05 * size); + look_ahead_size = 256 * ((look_ahead_size + 255)/256); +- ggml_cuda_set_device(device); +- CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); ++ if (allocate) { ++ ggml_cuda_set_device(device); ++ if (ggml_cuda_device_malloc(&ptr, look_ahead_size, device) != cudaSuccess) { ++ last_alloc = look_ahead_size; ++ throw std::bad_alloc(); ++ } ++ } else { ++ ptr = (void *)CUDA_ALIGNMENT; ++ } + *actual_size = look_ahead_size; + pool_size += look_ahead_size; + #ifdef DEBUG_CUDA_MALLOC +@@ -441,10 +455,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { + } + } + GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n"); +- ggml_cuda_set_device(device); +- CUDA_CHECK(cudaFree(ptr)); ++ if (allocate) { ++ ggml_cuda_set_device(device); ++ CUDA_CHECK(cudaFree(ptr)); ++ } + pool_size -= size; + } ++ ++ bool alloc_memory() override { ++ return allocate; ++ } ++ ++ size_t alloc_size() override { ++ return pool_size + last_alloc; ++ } + }; + + // pool with virtual memory +@@ -456,18 +480,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { + CUdeviceptr pool_addr = 0; + size_t pool_used = 0; + size_t pool_size = 0; ++ bool allocate = true; ++ size_t last_alloc = 0; + size_t granularity; + #if defined(GGML_USE_HIP) + std::vector> mappings; + #endif + +- explicit ggml_cuda_pool_vmm(int device) : ++ explicit ggml_cuda_pool_vmm(int device, bool alloc) : + device(device), +- granularity(ggml_cuda_info().devices[device].vmm_granularity) { ++ granularity(ggml_cuda_info().devices[device].vmm_granularity), ++ allocate(alloc) { ++ if (!allocate) { ++ pool_addr = (CUdeviceptr)CUDA_ALIGNMENT; ++ } + } + + ~ggml_cuda_pool_vmm() { +- if (pool_addr != 0) { ++ if (pool_addr != 0 && allocate) { + #if defined(GGML_USE_HIP) + // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 + for (std::pair & mapping : mappings) { +@@ -494,35 +524,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { + + GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE); + +- // allocate more physical memory +- CUmemAllocationProp prop = {}; +- prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; +- prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; +- prop.location.id = device; +- CUmemGenericAllocationHandle handle; +- CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0)); +- +- // reserve virtual address space (if not already reserved) +- if (pool_addr == 0) { +- CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); +- } ++ if (allocate) { ++ // allocate more physical memory ++ CUmemAllocationProp prop = {}; ++ prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; ++ prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; ++ prop.location.id = device; ++ CUmemGenericAllocationHandle handle; ++ if (cuMemCreate(&handle, reserve_size, &prop, 0) != CUDA_SUCCESS) { ++ last_alloc = reserve_size; ++ throw std::bad_alloc(); ++ } + +- // map at the end of the pool +- CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); +- CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0)); +-#if defined(GGML_USE_HIP) +- mappings.push_back({start_ptr, reserve_size}); +-#endif ++ // reserve virtual address space (if not already reserved) ++ if (pool_addr == 0) { ++ CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); ++ } + +- // the memory allocation handle is no longer needed after mapping +- CU_CHECK(cuMemRelease(handle)); ++ // map at the end of the pool ++ CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); ++ if (cuMemMap(start_ptr, reserve_size, 0, handle, 0) != CUDA_SUCCESS) { ++ last_alloc = reserve_size; ++ CU_CHECK(cuMemRelease(handle)); ++ throw std::bad_alloc(); ++ } ++ ++ // the memory allocation handle is no longer needed after mapping ++ CU_CHECK(cuMemRelease(handle)); ++ ++ // set access ++ CUmemAccessDesc access = {}; ++ access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; ++ access.location.id = device; ++ access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; ++ if (cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1) != CUDA_SUCCESS) { ++ CU_CHECK(cuMemUnmap(start_ptr, reserve_size)); ++ last_alloc = reserve_size; ++ throw std::bad_alloc(); ++ } + +- // set access +- CUmemAccessDesc access = {}; +- access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; +- access.location.id = device; +- access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; +- CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); ++ #if defined(GGML_USE_HIP) ++ mappings.push_back({start_ptr, reserve_size}); ++ #endif ++ } + + // add to the pool + pool_size += reserve_size; +@@ -555,16 +599,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { + // all deallocations must be in reverse order of the allocations + GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); + } ++ ++ bool alloc_memory() override { ++ return allocate; ++ } ++ ++ size_t alloc_size() override { ++ return pool_size + last_alloc; ++ } + }; + #endif // defined(GGML_USE_VMM) + +-std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { ++std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device, bool alloc) { + #if defined(GGML_USE_VMM) + if (ggml_cuda_info().devices[device].vmm) { +- return std::unique_ptr(new ggml_cuda_pool_vmm(device)); ++ return std::unique_ptr(new ggml_cuda_pool_vmm(device, alloc)); + } + #endif // defined(GGML_USE_VMM) +- return std::unique_ptr(new ggml_cuda_pool_leg(device)); ++ return std::unique_ptr(new ggml_cuda_pool_leg(device, alloc)); + } + + // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error +@@ -748,11 +800,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac + } + + static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { +- return 128; ++ return CUDA_ALIGNMENT; + + GGML_UNUSED(buft); + } + ++static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_noalloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ++ ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; ++ ++ void * dev_ptr = (void *)ggml_backend_cuda_buffer_type_get_alignment(buft); ++ ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); ++ ++ return ggml_backend_buffer_init(buft, {}, ctx, size); ++} ++ + static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + size_t size = ggml_nbytes(tensor); + int64_t ne0 = tensor->ne[0]; +@@ -776,6 +837,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, + /* .is_host = */ NULL, ++ /* .noalloc_buffer = */ ggml_backend_cuda_buffer_type_noalloc_buffer, + }; + + ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { +@@ -2936,6 +2998,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, + + static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, + bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { ++ + // flag used to determine whether it is an integrated_gpu + const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + +@@ -2951,6 +3014,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx + continue; + } + ++ // When reserving, we are forcing CUDA graphs but this operation is not graph-safe so we need to skip it ++ if (reserving_graph && node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { ++ continue; ++ } ++ + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); + if (!disable_fusion) { + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { +@@ -3022,6 +3090,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx + + static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ++ cuda_ctx->pool_set_alloc(true); + + ggml_cuda_set_device(cuda_ctx->device); + +@@ -3101,6 +3170,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, + return GGML_STATUS_SUCCESS; + } + ++// This is used to skip operations that are not graph safe during the reservation process. ++bool reserving_graph = false; ++ ++static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, ggml_cgraph * cgraph, bool alloc) { ++ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ++ cuda_ctx->pool_set_alloc(alloc); ++ ++ #ifdef USE_CUDA_GRAPH ++ if (cuda_ctx->cuda_graph == nullptr) { ++ cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); ++ } ++ #endif ++ ++ ggml_cuda_set_device(cuda_ctx->device); ++ ++ { ++ std::lock_guard lock(ggml_cuda_lock); ++ ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed); ++ } ++ ++ reserving_graph = true; ++ ++ // Create CuBLAS handles early to avoid synchronous allocations during graph capture. ++ cuda_ctx->cublas_handle(); ++ ++ CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); ++ ++ enum ggml_status result = GGML_STATUS_SUCCESS; ++ ++ try { ++ bool use_cuda_graph = false; ++ bool cuda_graph_update_required = false; ++ bool graph_evaluated_or_captured = false; ++ ++ evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); ++ } catch (const std::exception &e) { ++ result = GGML_STATUS_FAILED; ++ } ++ ++ cudaGraph_t graph; ++ CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph)); ++ CUDA_CHECK(cudaGraphDestroy(graph)); ++ ++ reserving_graph = false; ++ ++ { ++ std::lock_guard lock(ggml_cuda_lock); ++ if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) { ++ ggml_cuda_lock_cv.notify_all(); ++ } ++ } ++ ++ return result; ++} ++ ++static size_t ggml_backend_cuda_buffer_size(ggml_backend_t backend) { ++ ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; ++ return ctx->pool_get_alloc_size(); ++} ++ ++static void ggml_backend_cuda_reset(ggml_backend_t backend) { ++ ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; ++ ctx->pools[ctx->device] = NULL; ++} ++ + static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + +@@ -3140,6 +3274,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { + /* .graph_compute = */ ggml_backend_cuda_graph_compute, + /* .event_record = */ ggml_backend_cuda_event_record, + /* .event_wait = */ ggml_backend_cuda_event_wait, ++ /* .graph_reserve = */ ggml_backend_cuda_graph_reserve, ++ /* .buffer_size = */ ggml_backend_cuda_buffer_size, ++ /* .reset = */ ggml_backend_cuda_reset, + }; + + static ggml_guid_t ggml_backend_cuda_guid() { diff --git a/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch b/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch new file mode 100644 index 000000000..84aefd1df --- /dev/null +++ b/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch @@ -0,0 +1,130 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jesse Gross +Date: Wed, 27 Aug 2025 14:39:48 -0700 +Subject: [PATCH] ggml: Enable resetting backend devices + +Touching a CUDA device causes the allocation of a primary context +with CUDA data structures (~300 MB of VRAM). If a device is +unused then it can be reset to free these data structures. +--- + ggml/include/ggml-backend.h | 1 + + ggml/src/ggml-backend-impl.h | 4 ++++ + ggml/src/ggml-backend.cpp | 8 ++++++++ + ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++++-- + ggml/src/ggml-cuda/vendors/hip.h | 1 + + 5 files changed, 29 insertions(+), 2 deletions(-) + +diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h +index b602a7c78..fda5ceb24 100644 +--- a/ggml/include/ggml-backend.h ++++ b/ggml/include/ggml-backend.h +@@ -167,6 +167,7 @@ extern "C" { + GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props); + GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device); + GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params); ++ GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device); + GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device); + GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device); + GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size); +diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h +index 81749a5a3..6f10c353b 100644 +--- a/ggml/src/ggml-backend-impl.h ++++ b/ggml/src/ggml-backend-impl.h +@@ -178,6 +178,10 @@ extern "C" { + ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev); + void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event); + void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event); ++ ++ // (optional) reset device, clearing existing allocations and context ++ // the caller must ensure that there are no outstanding buffers, as these will become invalid ++ void (*reset)(ggml_backend_dev_t dev); + }; + + struct ggml_backend_device { +diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp +index 05a842ed5..6556943b0 100644 +--- a/ggml/src/ggml-backend.cpp ++++ b/ggml/src/ggml-backend.cpp +@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par + return device->iface.init_backend(device, params); + } + ++void ggml_backend_dev_reset(ggml_backend_dev_t device) { ++ if (device->iface.reset == NULL) { ++ return; ++ } ++ ++ device->iface.reset(device); ++} ++ + ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) { + return device->iface.get_buffer_type(device); + } +diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu +index c7f9dc3a5..e43fde523 100644 +--- a/ggml/src/ggml-cuda/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda/ggml-cuda.cu +@@ -103,6 +103,11 @@ int ggml_cuda_get_device() { + return id; + } + ++void ggml_cuda_reset_device(int device) { ++ ggml_cuda_set_device(device); ++ CUDA_CHECK(cudaDeviceReset()); ++} ++ + static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { + ggml_cuda_set_device(device); + cudaError_t err; +@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back + props->description = ggml_backend_cuda_device_get_description(dev); + props->id = ggml_backend_cuda_device_get_id(dev); + props->type = ggml_backend_cuda_device_get_type(dev); +- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); ++ ++ // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device). ++ // If you need the memory data, call ggml_backend_dev_memory() explicitly. ++ props->memory_total = props->memory_free = 0; + + bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; + #ifdef GGML_CUDA_NO_PEER_COPY +@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g + CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); + } + ++static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) { ++ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ++ ggml_cuda_reset_device(ctx->device); ++} ++ + static const ggml_backend_device_i ggml_backend_cuda_device_interface = { + /* .get_name = */ ggml_backend_cuda_device_get_name, + /* .get_description = */ ggml_backend_cuda_device_get_description, +@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { + /* .event_new = */ ggml_backend_cuda_device_event_new, + /* .event_free = */ ggml_backend_cuda_device_event_free, + /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, ++ /* .reset = */ ggml_backend_cuda_device_reset, + }; + + // backend reg +@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { + dev_ctx->device = i; + dev_ctx->name = GGML_CUDA_NAME + std::to_string(i); + +- ggml_cuda_set_device(i); + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); + dev_ctx->description = prop.name; +diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h +index c31f31923..cf22e60d2 100644 +--- a/ggml/src/ggml-cuda/vendors/hip.h ++++ b/ggml/src/ggml-cuda/vendors/hip.h +@@ -40,6 +40,7 @@ + #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess + #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess + #define cudaDeviceProp hipDeviceProp_t ++#define cudaDeviceReset hipDeviceReset + #define cudaDeviceSynchronize hipDeviceSynchronize + #define cudaError_t hipError_t + #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled diff --git a/llama/patches/0025-harden-uncaught-exception-registration.patch b/llama/patches/0025-harden-uncaught-exception-registration.patch new file mode 100644 index 000000000..d5fc2598c --- /dev/null +++ b/llama/patches/0025-harden-uncaught-exception-registration.patch @@ -0,0 +1,28 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen +Date: Fri, 29 Aug 2025 16:53:08 -0700 +Subject: [PATCH] harden uncaught exception registration + +--- + ggml/src/ggml.cpp | 8 ++++++-- + 1 file changed, 6 insertions(+), 2 deletions(-) + +diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp +index 0d388d45..f5bcb446 100644 +--- a/ggml/src/ggml.cpp ++++ b/ggml/src/ggml.cpp +@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{ + return false; + } + const auto prev{std::get_terminate()}; +- GGML_ASSERT(prev != ggml_uncaught_exception); +- previous_terminate_handler = prev; ++ // GGML_ASSERT(prev != ggml_uncaught_exception); ++ if (prev != ggml_uncaught_exception) { ++ previous_terminate_handler = prev; ++ } else { ++ GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__); ++ } + std::set_terminate(ggml_uncaught_exception); + return true; + }(); diff --git a/llama/patches/0026-GPU-discovery-enhancements.patch b/llama/patches/0026-GPU-discovery-enhancements.patch new file mode 100644 index 000000000..534a5a386 --- /dev/null +++ b/llama/patches/0026-GPU-discovery-enhancements.patch @@ -0,0 +1,876 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen +Date: Tue, 26 Aug 2025 12:48:29 -0700 +Subject: [PATCH] GPU discovery enhancements + +Expose more information about the devices through backend props, and leverage +management libraries for more accurate VRAM usage reporting if available. +--- + ggml/include/ggml-backend.h | 9 + + ggml/src/CMakeLists.txt | 2 + + ggml/src/ggml-cuda/ggml-cuda.cu | 75 +++++- + ggml/src/ggml-cuda/vendors/hip.h | 1 + + ggml/src/ggml-impl.h | 8 + + ggml/src/ggml-metal/ggml-metal.m | 2 + + ggml/src/mem_hip.cpp | 449 +++++++++++++++++++++++++++++++ + ggml/src/mem_nvml.cpp | 172 ++++++++++++ + 8 files changed, 717 insertions(+), 1 deletion(-) + create mode 100644 ggml/src/mem_hip.cpp + create mode 100644 ggml/src/mem_nvml.cpp + +diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h +index fda5ceb24..7c2d86703 100644 +--- a/ggml/include/ggml-backend.h ++++ b/ggml/include/ggml-backend.h +@@ -158,6 +158,15 @@ extern "C" { + size_t memory_total; + enum ggml_backend_dev_type type; + struct ggml_backend_dev_caps caps; ++ int driver_major; ++ int driver_minor; ++ int compute_major; ++ int compute_minor; ++ int integrated; ++ int pci_bus_id; ++ int pci_device_id; ++ int pci_domain_id; ++ const char *library; + }; + + GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); +diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt +index 5158acd6a..3a428a22d 100644 +--- a/ggml/src/CMakeLists.txt ++++ b/ggml/src/CMakeLists.txt +@@ -203,6 +203,8 @@ add_library(ggml-base + ggml-threading.h + ggml-quants.c + ggml-quants.h ++ mem_hip.cpp ++ mem_nvml.cpp + gguf.cpp) + + target_include_directories(ggml-base PRIVATE .) +diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu +index e43fde523..14baf0fb1 100644 +--- a/ggml/src/ggml-cuda/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda/ggml-cuda.cu +@@ -279,6 +279,16 @@ static ggml_cuda_device_info ggml_cuda_init() { + for (int id = 0; id < info.device_count; ++id) { + int device_vmm = 0; + ++#if defined(GGML_USE_HIP) ++ if (std::getenv("GGML_CUDA_INIT") != NULL) { ++ GGML_LOG_INFO("%s: initializing rocBLAS on device %d\n", __func__, id); ++ CUDA_CHECK(cudaSetDevice(id)); ++ // rocblas_initialize will SIGABRT if the GPU isn't supported ++ rocblas_initialize(); ++ GGML_LOG_INFO("%s: rocBLAS initialized on device %d\n", __func__, id); ++ } ++#endif ++ + #if defined(GGML_USE_VMM) + CUdevice device; + CU_CHECK(cuDeviceGet(&device, id)); +@@ -332,9 +342,15 @@ static ggml_cuda_device_info ggml_cuda_init() { + #else + info.devices[id].smpbo = prop.sharedMemPerBlockOptin; + info.devices[id].cc = 100*prop.major + 10*prop.minor; ++#ifdef __CUDA_ARCH_LIST__ ++ if (std::getenv("GGML_CUDA_INIT") != NULL) { ++ GGML_ASSERT(ggml_cuda_has_arch(info.devices[id].cc) && "ggml was not compiled with support for this arch"); ++ } ++#endif // defined(__CUDA_ARCH_LIST__) + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", + ggml_cuda_parse_uuid(prop, id).c_str()); ++ + #endif // defined(GGML_USE_HIP) + } + +@@ -3215,6 +3231,14 @@ struct ggml_backend_cuda_device_context { + std::string name; + std::string description; + std::string id; ++ int major; ++ int minor; ++ int driver_major; ++ int driver_minor; ++ int integrated; ++ int pci_bus_id; ++ int pci_device_id; ++ int pci_domain_id; + }; + + static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { +@@ -3235,6 +3259,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { + static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + ggml_cuda_set_device(ctx->device); ++ ++#if defined(GGML_USE_HIP) ++ if (ggml_hip_mgmt_init() == 0) { ++ int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_hip_mgmt_release(); ++ return; ++ } ++ ggml_hip_mgmt_release(); ++ } ++#else ++ if (ggml_nvml_init() == 0) { ++ int status = ggml_nvml_get_device_memory(ctx->id.c_str(), free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_nvml_release(); ++ return; ++ } ++ ggml_nvml_release(); ++ } ++#endif + CUDA_CHECK(cudaMemGetInfo(free, total)); + } + +@@ -3243,6 +3289,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend + return GGML_BACKEND_DEVICE_TYPE_GPU; + } + ++#define GGML_HIP_NAME "HIP" + static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_cuda_device_get_name(dev); + props->description = ggml_backend_cuda_device_get_description(dev); +@@ -3253,6 +3300,23 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back + // If you need the memory data, call ggml_backend_dev_memory() explicitly. + props->memory_total = props->memory_free = 0; + ++ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ++#if defined(GGML_USE_HIP) ++ int cc = ggml_cuda_info().devices[ctx->device].cc - GGML_CUDA_CC_OFFSET_AMD; ++ props->compute_major = cc / 0x100; ++ props->compute_minor = cc - (props->compute_major * 0x100); ++#else ++ props->compute_major = ctx->major; ++ props->compute_minor = ctx->minor; ++#endif ++ props->driver_major = ctx->driver_major; ++ props->driver_minor = ctx->driver_minor; ++ props->integrated = ctx->integrated; ++ props->pci_bus_id = ctx->pci_bus_id; ++ props->pci_device_id = ctx->pci_device_id; ++ props->pci_domain_id = ctx->pci_domain_id; ++ props->library = GGML_CUDA_NAME; ++ + bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; + #ifdef GGML_CUDA_NO_PEER_COPY + bool events = false; +@@ -3843,6 +3907,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { + std::lock_guard lock(mutex); + if (!initialized) { + ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; ++ int driverVersion = 0; ++ CUDA_CHECK(cudaDriverGetVersion(&driverVersion)); + + for (int i = 0; i < ggml_cuda_info().device_count; i++) { + ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; +@@ -3853,7 +3919,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { + CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); + dev_ctx->description = prop.name; + dev_ctx->id = ggml_cuda_parse_uuid(prop, i); +- ++ dev_ctx->major = prop.major; ++ dev_ctx->minor = prop.minor; ++ dev_ctx->driver_major = driverVersion / 1000; ++ dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10; ++ dev_ctx->integrated = prop.integrated; ++ dev_ctx->pci_bus_id = prop.pciBusID; ++ dev_ctx->pci_device_id = prop.pciDeviceID; ++ dev_ctx->pci_domain_id = prop.pciDomainID; + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_cuda_device_interface, + /* .reg = */ ®, +diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h +index cf22e60d2..957a795f2 100644 +--- a/ggml/src/ggml-cuda/vendors/hip.h ++++ b/ggml/src/ggml-cuda/vendors/hip.h +@@ -42,6 +42,7 @@ + #define cudaDeviceProp hipDeviceProp_t + #define cudaDeviceReset hipDeviceReset + #define cudaDeviceSynchronize hipDeviceSynchronize ++#define cudaDriverGetVersion hipDriverGetVersion + #define cudaError_t hipError_t + #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled + #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled +diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h +index 19a7adb2d..b9b102a5e 100644 +--- a/ggml/src/ggml-impl.h ++++ b/ggml/src/ggml-impl.h +@@ -602,6 +602,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx + return true; + } + ++// Management libraries for fetching more accurate free VRAM data ++GGML_API int ggml_nvml_init(); ++GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total); ++GGML_API void ggml_nvml_release(); ++GGML_API int ggml_hip_mgmt_init(); ++GGML_API int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total); ++GGML_API void ggml_hip_mgmt_release(); ++ + #ifdef __cplusplus + } + #endif +diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m +index e4c31268f..ec6b385ba 100644 +--- a/ggml/src/ggml-metal/ggml-metal.m ++++ b/ggml/src/ggml-metal/ggml-metal.m +@@ -6523,12 +6523,14 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen + GGML_UNUSED(dev); + } + ++#define GGML_METAL_NAME "Metal" + static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_metal_device_get_name(dev); + props->description = ggml_backend_metal_device_get_description(dev); + props->id = "0"; + props->type = ggml_backend_metal_device_get_type(dev); + ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); ++ props->library = GGML_METAL_NAME; + props->caps = (struct ggml_backend_dev_caps) { + /* .async = */ false, + /* .host_buffer = */ false, +diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp +new file mode 100644 +index 000000000..8ef19b8cf +--- /dev/null ++++ b/ggml/src/mem_hip.cpp +@@ -0,0 +1,449 @@ ++#include "ggml.h" ++ ++#ifdef _WIN32 ++// AMD Device Library eXtra (ADLX) ++// ++// https://github.com/GPUOpen-LibrariesAndSDKs/ADLX ++// ++// This Windows-only library provides accurate VRAM reporting for AMD GPUs. ++// The runtime DLL is installed with every AMD Driver on Windows, however ++// the SDK isn't a part of the HIP SDK packaging. As such, we avoid including ++// the headers from the SDK to simplify building from source. ++// ++// ADLX relies heavily on function pointer tables. ++// Only the minimal set of types are defined below to facilitate ++// finding the target AMD GPU(s) and querying their current VRAM usage ++// Unused function parameters are commented out to avoid unnecessary type ++// definitions. ++ ++#include "ggml-impl.h" ++#include ++#include ++ ++#define WIN32_LEAN_AND_MEAN ++#ifndef NOMINMAX ++# define NOMINMAX ++#endif ++#include ++ ++namespace fs = std::filesystem; ++ ++#include ++#include ++ ++// Begin minimal ADLX definitions - derived from tag v1.0 (Dec 2022) ++typedef uint64_t adlx_uint64; ++typedef uint32_t adlx_uint32; ++typedef int32_t adlx_int32; ++typedef adlx_int32 adlx_int; ++typedef adlx_uint32 adlx_uint; ++typedef long adlx_long; ++typedef uint8_t adlx_uint8; ++typedef enum ++{ ++ ADLX_OK = 0, /**< @ENG_START_DOX This result indicates success. @ENG_END_DOX */ ++ ADLX_ALREADY_ENABLED, /**< @ENG_START_DOX This result indicates that the asked action is already enabled. @ENG_END_DOX */ ++ ADLX_ALREADY_INITIALIZED, /**< @ENG_START_DOX This result indicates that ADLX has a unspecified type of initialization. @ENG_END_DOX */ ++ ADLX_FAIL, /**< @ENG_START_DOX This result indicates an unspecified failure. @ENG_END_DOX */ ++ ADLX_INVALID_ARGS, /**< @ENG_START_DOX This result indicates that the arguments are invalid. @ENG_END_DOX */ ++ ADLX_BAD_VER, /**< @ENG_START_DOX This result indicates that the asked version is incompatible with the current version. @ENG_END_DOX */ ++ ADLX_UNKNOWN_INTERFACE, /**< @ENG_START_DOX This result indicates that an unknown interface was asked. @ENG_END_DOX */ ++ ADLX_TERMINATED, /**< @ENG_START_DOX This result indicates that the calls were made in an interface after ADLX was terminated. @ENG_END_DOX */ ++ ADLX_ADL_INIT_ERROR, /**< @ENG_START_DOX This result indicates that the ADL initialization failed. @ENG_END_DOX */ ++ ADLX_NOT_FOUND, /**< @ENG_START_DOX This result indicates that the item is not found. @ENG_END_DOX */ ++ ADLX_INVALID_OBJECT, /**< @ENG_START_DOX This result indicates that the method was called into an invalid object. @ENG_END_DOX */ ++ ADLX_ORPHAN_OBJECTS, /**< @ENG_START_DOX This result indicates that ADLX was terminated with outstanding ADLX objects. Any interface obtained from ADLX points to invalid memory and calls in their methods will result in unexpected behavior. @ENG_END_DOX */ ++ ADLX_NOT_SUPPORTED, /**< @ENG_START_DOX This result indicates that the asked feature is not supported. @ENG_END_DOX */ ++ ADLX_PENDING_OPERATION, /**< @ENG_START_DOX This result indicates a failure due to an operation currently in progress. @ENG_END_DOX */ ++ ADLX_GPU_INACTIVE /**< @ENG_START_DOX This result indicates that the GPU is inactive. @ENG_END_DOX */ ++} ADLX_RESULT; ++#define ADLX_SUCCEEDED(x) (ADLX_OK == (x) || ADLX_ALREADY_ENABLED == (x) || ADLX_ALREADY_INITIALIZED == (x)) ++#define ADLX_FAILED(x) (ADLX_OK != (x) && ADLX_ALREADY_ENABLED != (x) && ADLX_ALREADY_INITIALIZED != (x)) ++#define ADLX_VER_MAJOR 1 ++#define ADLX_VER_MINOR 0 ++#define ADLX_VER_RELEASE 5 ++#define ADLX_VER_BUILD_NUM 30 ++#define ADLX_MAKE_FULL_VER(VERSION_MAJOR, VERSION_MINOR, VERSION_RELEASE, VERSION_BUILD_NUM) ( ((adlx_uint64)(VERSION_MAJOR) << 48ull) | ((adlx_uint64)(VERSION_MINOR) << 32ull) | ((adlx_uint64)(VERSION_RELEASE) << 16ull) | (adlx_uint64)(VERSION_BUILD_NUM)) ++#define ADLX_FULL_VERSION ADLX_MAKE_FULL_VER(ADLX_VER_MAJOR, ADLX_VER_MINOR, ADLX_VER_RELEASE, ADLX_VER_BUILD_NUM) ++#define ADLX_CORE_LINK __declspec(dllexport) ++#define ADLX_STD_CALL __stdcall ++#define ADLX_CDECL_CALL __cdecl ++#define ADLX_FAST_CALL __fastcall ++#define ADLX_INLINE __inline ++#define ADLX_FORCEINLINE __forceinline ++#define ADLX_NO_VTABLE __declspec(novtable) ++ ++#if defined(__cplusplus) ++typedef bool adlx_bool; ++#else ++typedef adlx_uint8 adlx_bool; ++#define true 1 ++#define false 0 ++#endif ++ ++typedef struct IADLXSystem IADLXSystem; ++typedef struct IADLXGPUList IADLXGPUList; ++typedef struct IADLXGPU IADLXGPU; ++typedef struct IADLXInterface IADLXInterface; ++typedef struct IADLXPerformanceMonitoringServices IADLXPerformanceMonitoringServices; ++typedef struct IADLXGPUMetrics IADLXGPUMetrics; ++typedef struct IADLXGPUMetricsSupport IADLXGPUMetricsSupport; ++ ++typedef struct IADLXSystemVtbl ++{ ++ // IADLXSystem interface ++ ADLX_RESULT (ADLX_STD_CALL *GetHybridGraphicsType)(/* IADLXSystem* pThis, ADLX_HG_TYPE* hgType */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUs)(IADLXSystem* pThis, IADLXGPUList** ppGPUs); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXSystem* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ADLX_RESULT (ADLX_STD_CALL *GetDisplaysServices)(/* IADLXSystem* pThis, IADLXDisplayServices** ppDispServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetDesktopsServices)(/* IADLXSystem* pThis, IADLXDesktopServices** ppDeskServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUsChangedHandling)(/* IADLXSystem* pThis, IADLXGPUsChangedHandling** ppGPUsChangedHandling */); ++ ADLX_RESULT (ADLX_STD_CALL *EnableLog)(/* IADLXSystem* pThis, ADLX_LOG_DESTINATION mode, ADLX_LOG_SEVERITY severity, IADLXLog* pLogger, const wchar_t* fileName */); ++ ADLX_RESULT (ADLX_STD_CALL *Get3DSettingsServices)(/* IADLXSystem* pThis, IADLX3DSettingsServices** pp3DSettingsServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUTuningServices)(/* IADLXSystem* pThis, IADLXGPUTuningServices** ppGPUTuningServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetPerformanceMonitoringServices)(IADLXSystem* pThis, IADLXPerformanceMonitoringServices** ppPerformanceMonitoringServices); // Used ++ ADLX_RESULT (ADLX_STD_CALL *TotalSystemRAM)(/* IADLXSystem* pThis, adlx_uint* ramMB */); ++ ADLX_RESULT (ADLX_STD_CALL *GetI2C)(/* IADLXSystem* pThis, IADLXGPU* pGPU, IADLXI2C** ppI2C */); ++} IADLXSystemVtbl; ++struct IADLXSystem { const IADLXSystemVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPU* pThis */); ++ adlx_long (ADLX_STD_CALL *Release)(IADLXGPU* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPU* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXGPU ++ ADLX_RESULT (ADLX_STD_CALL *VendorId)(/* IADLXGPU* pThis, const char** vendorId */); ++ ADLX_RESULT (ADLX_STD_CALL *ASICFamilyType)(/* IADLXGPU* pThis, ADLX_ASIC_FAMILY_TYPE* asicFamilyType */); ++ ADLX_RESULT (ADLX_STD_CALL *Type)(/* IADLXGPU* pThis, ADLX_GPU_TYPE* gpuType */); ++ ADLX_RESULT (ADLX_STD_CALL *IsExternal)(/* IADLXGPU* pThis, adlx_bool* isExternal */); ++ ADLX_RESULT (ADLX_STD_CALL *Name)(/* IADLXGPU* pThis, const char** gpuName */); ++ ADLX_RESULT (ADLX_STD_CALL *DriverPath)(/* IADLXGPU* pThis, const char** driverPath */); ++ ADLX_RESULT (ADLX_STD_CALL *PNPString)(/* IADLXGPU* pThis, const char** pnpString */); ++ ADLX_RESULT (ADLX_STD_CALL *HasDesktops)(/* IADLXGPU* pThis, adlx_bool* hasDesktops */); ++ ADLX_RESULT (ADLX_STD_CALL *TotalVRAM)(IADLXGPU* pThis, adlx_uint* vramMB); // Used ++ ADLX_RESULT (ADLX_STD_CALL *VRAMType)(/* IADLXGPU* pThis, const char** type */); ++ ADLX_RESULT (ADLX_STD_CALL *BIOSInfo)(/* IADLXGPU* pThis, const char** partNumber, const char** version, const char** date */); ++ ADLX_RESULT (ADLX_STD_CALL *DeviceId)(/* IADLXGPU* pThis, const char** deviceId */); ++ ADLX_RESULT (ADLX_STD_CALL *RevisionId)(/* IADLXGPU* pThis, const char** revisionId */); ++ ADLX_RESULT (ADLX_STD_CALL *SubSystemId)(/* IADLXGPU* pThis, const char** subSystemId */); ++ ADLX_RESULT (ADLX_STD_CALL *SubSystemVendorId)(/* IADLXGPU* pThis, const char** subSystemVendorId */); ++ ADLX_RESULT (ADLX_STD_CALL *UniqueId)(IADLXGPU* pThis, adlx_int* uniqueId); // Used ++} IADLXGPUVtbl; ++struct IADLXGPU { const IADLXGPUVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUListVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPUList* pThis */); ++ adlx_long (ADLX_STD_CALL *Release)(IADLXGPUList* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPUList* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXList ++ adlx_uint (ADLX_STD_CALL *Size)(/* IADLXGPUList* pThis */); ++ adlx_uint8 (ADLX_STD_CALL *Empty)(/* IADLXGPUList* pThis */); ++ adlx_uint (ADLX_STD_CALL *Begin)(IADLXGPUList* pThis); // Used ++ adlx_uint (ADLX_STD_CALL *End)(IADLXGPUList* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *At)(/* IADLXGPUList* pThis, const adlx_uint location, IADLXInterface** ppItem */); ++ ADLX_RESULT (ADLX_STD_CALL *Clear)(/* IADLXGPUList* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *Remove_Back)(/* IADLXGPUList* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *Add_Back)(/* IADLXGPUList* pThis, IADLXInterface* pItem */); ++ ++ //IADLXGPUList ++ ADLX_RESULT (ADLX_STD_CALL *At_GPUList)(IADLXGPUList* pThis, const adlx_uint location, IADLXGPU** ppItem); // Used ++ ADLX_RESULT (ADLX_STD_CALL *Add_Back_GPUList)(/* IADLXGPUList* pThis, IADLXGPU* pItem */); ++ ++} IADLXGPUListVtbl; ++struct IADLXGPUList { const IADLXGPUListVtbl *pVtbl; }; ++ ++typedef struct IADLXPerformanceMonitoringServicesVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXPerformanceMonitoringServices* pThis */); ++ adlx_long (ADLX_STD_CALL *Release)(IADLXPerformanceMonitoringServices* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXPerformanceMonitoringServices* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXPerformanceMonitoringServices ++ ADLX_RESULT (ADLX_STD_CALL *GetSamplingIntervalRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); ++ ADLX_RESULT (ADLX_STD_CALL *SetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int intervalMs */); ++ ADLX_RESULT (ADLX_STD_CALL *GetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* intervalMs */); ++ ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySizeRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); ++ ADLX_RESULT (ADLX_STD_CALL *SetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int sizeSec */); ++ ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); ++ ADLX_RESULT (ADLX_STD_CALL *ClearPerformanceMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); ++ ADLX_RESULT (ADLX_STD_CALL *StartPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *StopPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *GetAllMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXAllMetricsList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, adlx_int startMs, adlx_int stopMs, IADLXGPUMetricsList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetSystemMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXSystemMetricsList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetFPSHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXFPSList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentAllMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXAllMetrics** ppMetrics */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetrics** ppMetrics); // Used ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetrics** ppMetrics */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentFPS)(/* IADLXPerformanceMonitoringServices* pThis, IADLXFPS** ppMetrics */); ++ ADLX_RESULT (ADLX_STD_CALL *GetSupportedGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetricsSupport** ppMetricsSupported); // Used ++ ADLX_RESULT (ADLX_STD_CALL *GetSupportedSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetricsSupport** ppMetricsSupported */); ++}IADLXPerformanceMonitoringServicesVtbl; ++struct IADLXPerformanceMonitoringServices { const IADLXPerformanceMonitoringServicesVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUMetricsSupportVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetricsSupport* pThis */); ++ adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetricsSupport* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetricsSupport* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXGPUMetricsSupport ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUUsage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAMClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUHotspotTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTotalBoardPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUFanSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAM)(IADLXGPUMetricsSupport* pThis, adlx_bool* supported); // Used ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVoltage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUUsageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUHotspotTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUFanSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUVoltageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUTotalBoardPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++} IADLXGPUMetricsSupportVtbl; ++struct IADLXGPUMetricsSupport { const IADLXGPUMetricsSupportVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUMetricsVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetrics* pThis */); ++ adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetrics* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetrics* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXGPUMetrics ++ ADLX_RESULT (ADLX_STD_CALL* TimeStamp)(/* IADLXGPUMetrics* pThis, adlx_int64* ms */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUUsage)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUVRAMClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUHotspotTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUTotalBoardPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUFanSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUVRAM)(IADLXGPUMetrics* pThis, adlx_int* data); // Used ++ ADLX_RESULT (ADLX_STD_CALL* GPUVoltage)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++} IADLXGPUMetricsVtbl; ++struct IADLXGPUMetrics { const IADLXGPUMetricsVtbl *pVtbl; }; ++ ++struct { ++ void *handle; ++ ADLX_RESULT (*ADLXInitialize)(adlx_uint64 version, IADLXSystem** ppSystem); ++ ADLX_RESULT (*ADLXInitializeWithIncompatibleDriver)(adlx_uint64 version, IADLXSystem** ppSystem); ++ ADLX_RESULT (*ADLXQueryVersion)(const char** version); ++ ADLX_RESULT (*ADLXTerminate)(); ++ IADLXSystem *sys; ++} adlx { NULL, NULL, NULL, NULL, NULL, NULL }; ++static std::mutex ggml_adlx_lock; ++ ++extern "C" { ++ ++int ggml_hip_mgmt_init() { ++ std::lock_guard lock(ggml_adlx_lock); ++ if (adlx.handle != NULL) { ++ // Already initialized ++ return 0; ++ } ++ DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); ++ SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); ++ fs::path libPath = fs::path("\\Windows") / fs::path("System32") / fs::path("amdadlx64.dll"); ++ ++ adlx.handle = (void*)LoadLibraryW(libPath.wstring().c_str()); ++ if (adlx.handle == NULL) { ++ return ADLX_NOT_FOUND; ++ } ++ ++ adlx.ADLXInitialize = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitialize"); ++ adlx.ADLXInitializeWithIncompatibleDriver = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitializeWithIncompatibleDriver"); ++ adlx.ADLXTerminate = (ADLX_RESULT (*)()) GetProcAddress((HMODULE)(adlx.handle), "ADLXTerminate"); ++ adlx.ADLXQueryVersion = (ADLX_RESULT (*)(const char **version)) GetProcAddress((HMODULE)(adlx.handle), "ADLXQueryVersion"); ++ if (adlx.ADLXInitialize == NULL || adlx.ADLXInitializeWithIncompatibleDriver == NULL || adlx.ADLXTerminate == NULL) { ++ GGML_LOG_INFO("%s unable to locate required symbols in amdadlx64.dll, falling back to hip free memory reporting", __func__); ++ FreeLibrary((HMODULE)(adlx.handle)); ++ adlx.handle = NULL; ++ return ADLX_NOT_FOUND; ++ } ++ ++ SetErrorMode(old_mode); ++ ++ // Aid in troubleshooting... ++ if (adlx.ADLXQueryVersion != NULL) { ++ const char *version = NULL; ++ ADLX_RESULT status = adlx.ADLXQueryVersion(&version); ++ if (ADLX_SUCCEEDED(status)) { ++ GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version); ++ } ++ } ++ ++ ADLX_RESULT status = adlx.ADLXInitialize(ADLX_FULL_VERSION, &adlx.sys); ++ if (ADLX_FAILED(status)) { ++ // GGML_LOG_DEBUG("%s failed to initialize ADLX error=%d - attempting with incompatible driver...\n", __func__, status); ++ // Try with the incompatible driver ++ status = adlx.ADLXInitializeWithIncompatibleDriver(ADLX_FULL_VERSION, &adlx.sys); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s failed to initialize ADLX error=%d\n", __func__, status); ++ FreeLibrary((HMODULE)(adlx.handle)); ++ adlx.handle = NULL; ++ adlx.sys = NULL; ++ return status; ++ } ++ // GGML_LOG_DEBUG("%s initialized ADLX with incpomatible driver\n", __func__); ++ } ++ return ADLX_OK; ++} ++ ++void ggml_hip_mgmt_release() { ++ std::lock_guard lock(ggml_adlx_lock); ++ if (adlx.handle == NULL) { ++ // Already free ++ return; ++ } ++ ADLX_RESULT status = adlx.ADLXTerminate(); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s failed to terminate Adlx %d\n", __func__, status); ++ // Unload anyway... ++ } ++ FreeLibrary((HMODULE)(adlx.handle)); ++ adlx.handle = NULL; ++} ++ ++#define adlx_gdm_cleanup \ ++ if (gpuMetricsSupport != NULL) gpuMetricsSupport->pVtbl->Release(gpuMetricsSupport); \ ++ if (gpuMetrics != NULL) gpuMetrics->pVtbl->Release(gpuMetrics); \ ++ if (perfMonitoringServices != NULL) perfMonitoringServices->pVtbl->Release(perfMonitoringServices); \ ++ if (gpus != NULL) gpus->pVtbl->Release(gpus); \ ++ if (gpu != NULL) gpu->pVtbl->Release(gpu) ++ ++int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { ++ std::lock_guard lock(ggml_adlx_lock); ++ if (adlx.handle == NULL) { ++ GGML_LOG_INFO("%s ADLX was not initialized\n", __func__); ++ return ADLX_ADL_INIT_ERROR; ++ } ++ IADLXGPUMetricsSupport *gpuMetricsSupport = NULL; ++ IADLXPerformanceMonitoringServices *perfMonitoringServices = NULL; ++ IADLXGPUList* gpus = NULL; ++ IADLXGPU* gpu = NULL; ++ IADLXGPUMetrics *gpuMetrics = NULL; ++ ADLX_RESULT status; ++ // The "UniqueID" exposed in ADLX is the PCI Bus and Device IDs ++ adlx_int target = (pci_bus_id << 8) | (pci_device_id & 0xff); ++ ++ status = adlx.sys->pVtbl->GetPerformanceMonitoringServices(adlx.sys, &perfMonitoringServices); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetPerformanceMonitoringServices failed %d\n", __func__, status); ++ return status; ++ } ++ ++ status = adlx.sys->pVtbl->GetGPUs(adlx.sys, &gpus); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetGPUs failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ // Get GPU list ++ for (adlx_uint crt = gpus->pVtbl->Begin(gpus); crt != gpus->pVtbl->End(gpus); ++crt) ++ { ++ status = gpus->pVtbl->At_GPUList(gpus, crt, &gpu); ++ if (ADLX_FAILED(status)) ++ { ++ GGML_LOG_INFO("%s %d] At_GPUList failed %d\n", __func__, crt, status); ++ continue; ++ } ++ adlx_int id; ++ status = gpu->pVtbl->UniqueId(gpu, &id); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s %d] UniqueId lookup failed %d\n", __func__, crt, status); ++ gpu->pVtbl->Release(gpu); ++ gpu = NULL; ++ continue; ++ } ++ if (id != target) { ++ GGML_LOG_DEBUG("%s %d] GPU UniqueId: %x does not match target %02x %02x\n", __func__, crt, id, pci_bus_id, pci_device_id); ++ gpu->pVtbl->Release(gpu); ++ gpu = NULL; ++ continue; ++ } ++ // Any failures at this point should cause a fall-back to other APIs ++ status = perfMonitoringServices->pVtbl->GetSupportedGPUMetrics(perfMonitoringServices, gpu, &gpuMetricsSupport); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetSupportedGPUMetrics failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ status = perfMonitoringServices->pVtbl->GetCurrentGPUMetrics(perfMonitoringServices, gpu, &gpuMetrics); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetCurrentGPUMetrics failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ adlx_bool supported = false; ++ status = gpuMetricsSupport->pVtbl->IsSupportedGPUVRAM(gpuMetricsSupport, &supported); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s IsSupportedGPUVRAM failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ adlx_uint totalVRAM = 0; ++ status = gpu->pVtbl->TotalVRAM(gpu, &totalVRAM); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s TotalVRAM failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ adlx_int usedVRAM = 0; ++ status = gpuMetrics->pVtbl->GPUVRAM(gpuMetrics, &usedVRAM); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GPUVRAM failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ *total = size_t(totalVRAM) * 1024 * 1024; ++ *free = size_t(totalVRAM-usedVRAM) * 1024 * 1024; ++ ++ adlx_gdm_cleanup; ++ return ADLX_OK; ++ } ++ adlx_gdm_cleanup; ++ return ADLX_NOT_FOUND; ++} ++ ++} // extern "C" ++ ++#else // #ifdef _WIN32 ++ ++extern "C" { ++ ++// TODO Linux implementation of accurate VRAM reporting ++int ggml_hip_mgmt_init() { ++ return -1; ++} ++void ggml_hip_mgmt_release() {} ++int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { ++ return -1; ++} ++ ++} // extern "C" ++ ++#endif // #ifdef _WIN32 +\ No newline at end of file +diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp +new file mode 100644 +index 000000000..aa05e9dc1 +--- /dev/null ++++ b/ggml/src/mem_nvml.cpp +@@ -0,0 +1,172 @@ ++// NVIDIA Management Library (NVML) ++// ++// https://developer.nvidia.com/management-library-nvml ++// ++// This library provides accurate VRAM reporting for NVIDIA GPUs, particularly ++// on Windows, where the cuda library provides inaccurate VRAM usage metrics. The ++// runtime DLL is installed with every driver on Windows, and most Linux ++// systems, and the headers are included in the standard CUDA SDK install. As ++// such, we can include the header here to simplify the code. ++ ++ ++#include "ggml-impl.h" ++#include ++#include ++ ++#ifdef _WIN32 ++# define WIN32_LEAN_AND_MEAN ++# ifndef NOMINMAX ++# define NOMINMAX ++# endif ++# include ++#else ++# include ++# include ++#endif ++ ++namespace fs = std::filesystem; ++ ++// Minimal definitions to avoid including the nvml.h header ++typedef enum nvmlReturn_enum ++{ ++ // cppcheck-suppress * ++ NVML_SUCCESS = 0, //!< The operation was successful ++ NVML_ERROR_UNINITIALIZED = 1, //!< NVML was not first initialized with nvmlInit() ++ NVML_ERROR_INVALID_ARGUMENT = 2, //!< A supplied argument is invalid ++ NVML_ERROR_NOT_SUPPORTED = 3, //!< The requested operation is not available on target device ++ NVML_ERROR_NO_PERMISSION = 4, //!< The current user does not have permission for operation ++ NVML_ERROR_ALREADY_INITIALIZED = 5, //!< Deprecated: Multiple initializations are now allowed through ref counting ++ NVML_ERROR_NOT_FOUND = 6, //!< A query to find an object was unsuccessful ++ NVML_ERROR_INSUFFICIENT_SIZE = 7, //!< An input argument is not large enough ++ NVML_ERROR_INSUFFICIENT_POWER = 8, //!< A device's external power cables are not properly attached ++ NVML_ERROR_DRIVER_NOT_LOADED = 9, //!< NVIDIA driver is not loaded ++ NVML_ERROR_TIMEOUT = 10, //!< User provided timeout passed ++ NVML_ERROR_IRQ_ISSUE = 11, //!< NVIDIA Kernel detected an interrupt issue with a GPU ++ NVML_ERROR_LIBRARY_NOT_FOUND = 12, //!< NVML Shared Library couldn't be found or loaded ++ NVML_ERROR_FUNCTION_NOT_FOUND = 13, //!< Local version of NVML doesn't implement this function ++ NVML_ERROR_CORRUPTED_INFOROM = 14, //!< infoROM is corrupted ++ NVML_ERROR_GPU_IS_LOST = 15, //!< The GPU has fallen off the bus or has otherwise become inaccessible ++ NVML_ERROR_RESET_REQUIRED = 16, //!< The GPU requires a reset before it can be used again ++ NVML_ERROR_OPERATING_SYSTEM = 17, //!< The GPU control device has been blocked by the operating system/cgroups ++ NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18, //!< RM detects a driver/library version mismatch ++ NVML_ERROR_IN_USE = 19, //!< An operation cannot be performed because the GPU is currently in use ++ NVML_ERROR_MEMORY = 20, //!< Insufficient memory ++ NVML_ERROR_NO_DATA = 21, //!< No data ++ NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, becasue ECC is enabled ++ NVML_ERROR_INSUFFICIENT_RESOURCES = 23, //!< Ran out of critical resources, other than memory ++ NVML_ERROR_FREQ_NOT_SUPPORTED = 24, //!< Ran out of critical resources, other than memory ++ NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25, //!< The provided version is invalid/unsupported ++ NVML_ERROR_DEPRECATED = 26, //!< The requested functionality has been deprecated ++ NVML_ERROR_NOT_READY = 27, //!< The system is not ready for the request ++ NVML_ERROR_GPU_NOT_FOUND = 28, //!< No GPUs were found ++ NVML_ERROR_INVALID_STATE = 29, //!< Resource not in correct state to perform requested operation ++ NVML_ERROR_UNKNOWN = 999 //!< An internal driver error occurred ++} nvmlReturn_t; ++typedef struct nvmlDevice_st* nvmlDevice_t; ++typedef struct nvmlMemory_st ++{ ++ unsigned long long total; //!< Total physical device memory (in bytes) ++ unsigned long long free; //!< Unallocated device memory (in bytes) ++ unsigned long long used; //!< Sum of Reserved and Allocated device memory (in bytes). ++ //!< Note that the driver/GPU always sets aside a small amount of memory for bookkeeping ++} nvmlMemory_t; ++// end nvml.h definitions ++ ++struct { ++ void *handle; ++ nvmlReturn_t (*nvmlInit_v2)(void); ++ nvmlReturn_t (*nvmlShutdown)(void); ++ nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); ++ nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); ++} nvml { NULL, NULL, NULL, NULL, NULL }; ++static std::mutex ggml_nvml_lock; ++ ++extern "C" { ++ ++int ggml_nvml_init() { ++ std::lock_guard lock(ggml_nvml_lock); ++ if (nvml.handle != NULL) { ++ // Already initialized ++ return 0; ++ } ++#ifdef _WIN32 ++ DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); ++ SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); ++ fs::path libPath[2]; ++ const char * programDir = std::getenv("ProgramW6432"); ++ if (programDir == NULL) { ++ libPath[0] = fs::path("Program Files") / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); ++ } else { ++ libPath[0] = fs::path(programDir) / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); ++ } ++ libPath[1] = fs::path("\\Windows") / fs::path("System32") / fs::path("NVML.dll"); ++ ++ for (int i = 0; i < 2; i++) { ++ nvml.handle = (void*)LoadLibraryW(libPath[i].wstring().c_str()); ++ if (nvml.handle != NULL) { ++ break; ++ } ++ } ++ if (nvml.handle == NULL) { ++ return NVML_ERROR_NOT_FOUND; ++ } ++ ++ nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlInit_v2"); ++ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown"); ++ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID"); ++ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo"); ++ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) { ++ GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__); ++ FreeLibrary((HMODULE)(nvml.handle)); ++ nvml.handle = NULL; ++ return NVML_ERROR_NOT_FOUND; ++ } ++ ++ SetErrorMode(old_mode); ++ ++#else ++ // Not currently wired up on Linux ++ return NVML_ERROR_NOT_SUPPORTED; ++#endif ++ int status = nvml.nvmlInit_v2(); ++ return NVML_SUCCESS; ++} ++ ++void ggml_nvml_release() { ++ std::lock_guard lock(ggml_nvml_lock); ++ if (nvml.handle == NULL) { ++ // Already free ++ return; ++ } ++ nvmlReturn_enum status = nvml.nvmlShutdown(); ++ if (status != NVML_SUCCESS) { ++ GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status); ++ } ++#ifdef _WIN32 ++ FreeLibrary((HMODULE)(nvml.handle)); ++ nvml.handle = NULL; ++#else ++ // Not currently wired up on Linux ++#endif ++} ++ ++int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) { ++ std::lock_guard lock(ggml_nvml_lock); ++ if (nvml.handle == NULL) { ++ return NVML_ERROR_UNINITIALIZED; ++ } ++ nvmlDevice_t device; ++ auto status = nvml.nvmlDeviceGetHandleByUUID(uuid, &device); ++ if (status != NVML_SUCCESS) { ++ return status; ++ } ++ nvmlMemory_t memInfo = {0}; ++ status = nvml.nvmlDeviceGetMemoryInfo(device, &memInfo); ++ if (status == NVML_SUCCESS) { ++ *free = memInfo.free; ++ *total = memInfo.total; ++ } ++ return status; ++} ++ ++} +\ No newline at end of file diff --git a/llama/patches/0026-ggml-Backport-scale-kernel-fixes.patch b/llama/patches/0026-ggml-Backport-scale-kernel-fixes.patch new file mode 100644 index 000000000..651c97ad4 --- /dev/null +++ b/llama/patches/0026-ggml-Backport-scale-kernel-fixes.patch @@ -0,0 +1,57 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jesse Gross +Date: Tue, 23 Sep 2025 15:41:58 -0700 +Subject: [PATCH] ggml: Backport scale kernel fixes + +The GGML scale kernel uses signed 32-bit ints to represent +the number of elements in the tensor. For large images, +mistral-small3.2 overflows this, triggering CUDA errors due +to negative arguments. + +Currently, this can happen when the user passes a large image +to mistral-small3.2. However, with upcoming changes to reserve +CUDA memory, it happens every time mistral-small is loaded as +we reserve using a worst case batch. + +This patch is part of an upstream GGML commit and should be removed +after GGML is updated past 0a1b398 "ggml: add ops for WAN video model +(cuda && cpu) (#15669)". + +Fixes #10388 +--- + ggml/src/ggml-cuda/scale.cu | 19 ++++++++++--------- + 1 file changed, 10 insertions(+), 9 deletions(-) + +diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu +index 2ee9e5889..0ddeff6a1 100644 +--- a/ggml/src/ggml-cuda/scale.cu ++++ b/ggml/src/ggml-cuda/scale.cu +@@ -1,18 +1,19 @@ + #include "scale.cuh" + +-static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) { +- const int i = blockDim.x*blockIdx.x + threadIdx.x; ++#define MAX_GRIDDIM_X 0x7FFFFFFF + +- if (i >= k) { +- return; +- } ++static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { ++ int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; ++ int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; + +- dst[i] = scale * x[i] + bias; ++ for (int64_t i = tid; i < nelements; i += stride) { ++ dst[i] = scale * x[i] + bias; ++ } + } + +-static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) { +- const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; +- scale_f32<<>>(x, dst, scale, bias, k); ++static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { ++ const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; ++ scale_f32<<>>(x, dst, scale, bias, nelements); + } + + void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/llm/memory.go b/llm/memory.go index ee4be7419..6f192b35d 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -30,7 +30,7 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin // Try to pack into as few GPUs as possible, starting from 1 GPU for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ { gpuSubset := sgl[:numGPUs] - ok, estimatedVRAM := PredictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel) + ok, estimatedVRAM := predictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel) if ok { slog.Info("new model will fit in available VRAM across minimum required GPUs, loading", @@ -48,7 +48,7 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin // - try subsets of GPUs instead of just falling back to 1 or all in a family // Now try all the GPUS (OLLAMA_SCHED_SPREAD is set) - if ok, estimatedVRAM := PredictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok { + if ok, estimatedVRAM := predictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok { slog.Info("new model will fit in available VRAM, loading", "model", modelPath, "library", sgl[0].Library, @@ -71,7 +71,7 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s var bestEstimate uint64 var bestFit int for i, gl := range byLibrary { - _, estimatedVRAM := PredictServerFit(gl, f, adapters, projectors, opts, numParallel) + _, estimatedVRAM := predictServerFit(gl, f, adapters, projectors, opts, numParallel) if estimatedVRAM > bestEstimate { bestEstimate = estimatedVRAM bestFit = i @@ -81,7 +81,7 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s } // This algorithm looks for a complete fit to determine if we need to unload other models -func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) { +func predictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) { // Split up the GPUs by type and try them var estimatedVRAM uint64 for _, gpus := range allGpus.ByLibrary() { @@ -97,6 +97,10 @@ func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, proj return true, estimatedVRAM } } + + if len(gpus) == 1 && gpus[0].Library == "cpu" && estimate.TotalSize <= gpus[0].FreeMemory { + return true, estimatedVRAM + } } return false, estimatedVRAM } @@ -191,17 +195,19 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin slog.Warn("model missing blk.0 layer size") } + useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) && + (discover.GpuInfoList)(gpus).FlashAttentionSupported() && + f.SupportsFlashAttention() + var kvct string - if envconfig.FlashAttention() && - discover.GetGPUInfo().FlashAttentionSupported() && - f.SupportsFlashAttention() { + if useFlashAttention { requested := strings.ToLower(envconfig.KvCacheType()) - if requested != "" && f.SupportsKVCacheType(requested) { + if f.SupportsKVCacheType(requested) { kvct = requested } } - kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct) + kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention) if len(kv) > 0 { layerSize += kv[0] @@ -225,7 +231,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin } // on metal there's no partial offload overhead - if gpus[0].Library == "metal" { + if gpus[0].Library == "Metal" { graphPartialOffload = graphFullOffload } else if len(gpus) > 1 { // multigpu should always use the partial graph size diff --git a/llm/memory_test.go b/llm/memory_test.go index 49851006c..553214b9e 100644 --- a/llm/memory_test.go +++ b/llm/memory_test.go @@ -12,6 +12,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/discover" "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/ml" ) func TestEstimateGPULayers(t *testing.T) { @@ -55,7 +56,9 @@ func TestEstimateGPULayers(t *testing.T) { // Simple CPU scenario gpus := []discover.GpuInfo{ { - Library: "cpu", + DeviceID: ml.DeviceID{ + Library: "cpu", + }, }, } projectors := []string{} @@ -77,11 +80,15 @@ func TestEstimateGPULayers(t *testing.T) { gpuMinimumMemory := uint64(2048) gpus = []discover.GpuInfo{ { - Library: "cuda", + DeviceID: ml.DeviceID{ + Library: "cuda", + }, MinimumMemory: gpuMinimumMemory, }, { - Library: "cuda", + DeviceID: ml.DeviceID{ + Library: "cuda", + }, MinimumMemory: gpuMinimumMemory, }, } diff --git a/llm/server.go b/llm/server.go index ecdaa90e9..63ad6085c 100644 --- a/llm/server.go +++ b/llm/server.go @@ -66,7 +66,7 @@ func (e filteredEnv) LogValue() slog.Value { type LlamaServer interface { ModelPath() string - Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error + Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error @@ -76,8 +76,11 @@ type LlamaServer interface { Close() error VRAMSize() uint64 // Total VRAM across all GPUs TotalSize() uint64 - VRAMByGPU(gpuID string) uint64 + VRAMByGPU(id ml.DeviceID) uint64 Pid() int + GetPort() int + GetDeviceInfos(ctx context.Context) []ml.DeviceInfo + HasExited() bool } // llmServer is an instance of a runner hosting a single model @@ -148,7 +151,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a var textProcessor model.TextProcessor var err error if envconfig.NewEngine() || f.KV().OllamaEngineRequired() { - textProcessor, err = model.NewTextProcessor(modelPath) + if len(projectors) == 0 { + textProcessor, err = model.NewTextProcessor(modelPath) + } else { + err = errors.New("split vision models aren't supported") + } if err != nil { // To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err) @@ -161,11 +168,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a } } - newEstimates := textProcessor != nil && envconfig.NewMemoryEstimates() - if newEstimates { - slog.Info("enabling new memory estimates") - } - // Verify the requested context size is <= the model training size trainCtx := f.KV().ContextLength() if opts.NumCtx > int(trainCtx) && trainCtx > 0 { @@ -173,6 +175,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a opts.NumCtx = int(trainCtx) } + opts.NumBatch = min(opts.NumBatch, opts.NumCtx) + loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()} defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount() @@ -195,6 +199,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a // This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset // that can handle it. fa := envconfig.FlashAttention() + if f.FlashAttention() { + slog.Info("model wants flash attention") + fa = true + } + if fa && !gpus.FlashAttentionSupported() { slog.Warn("flash attention enabled but not supported by gpu") fa = false @@ -213,7 +222,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a // Flash Attention also supports kv cache quantization // Enable if the requested and kv cache type is supported by the model - if kvct != "" && f.SupportsKVCacheType(kvct) { + if f.SupportsKVCacheType(kvct) { loadRequest.KvCacheType = kvct } else { slog.Warn("kv cache type not supported by model", "type", kvct) @@ -325,6 +334,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a if gpu.DependencyPath != nil { slog.Debug("adding gpu dependency paths", "paths", gpu.DependencyPath) libraryPaths = append(gpu.DependencyPath, libraryPaths...) + ggmlPaths = append(ggmlPaths, gpu.DependencyPath...) } } @@ -355,23 +365,24 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator))) - envWorkarounds := [][2]string{} - for _, gpu := range gpus { - envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...) - } + // Always filter down the set of GPUs in case there are any unsupported devices that might crash + envWorkarounds := gpus.GetVisibleDevicesEnv() pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) // Update or add the path variable with our adjusted version pathNeeded := true + envWorkaroundDone := make([]bool, len(envWorkarounds)) for i := range s.cmd.Env { cmp := strings.SplitN(s.cmd.Env[i], "=", 2) if strings.EqualFold(cmp[0], pathEnv) { s.cmd.Env[i] = pathEnv + "=" + pathEnvVal pathNeeded = false } else if len(envWorkarounds) != 0 { - for _, kv := range envWorkarounds { - if strings.EqualFold(cmp[0], kv[0]) { - s.cmd.Env[i] = kv[0] + "=" + kv[1] + for j, kv := range envWorkarounds { + tmp := strings.SplitN(kv, "=", 2) + if strings.EqualFold(cmp[0], tmp[0]) { + s.cmd.Env[i] = kv + envWorkaroundDone[j] = true } } } @@ -379,6 +390,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a if pathNeeded { s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal) } + for i, done := range envWorkaroundDone { + if !done { + s.cmd.Env = append(s.cmd.Env, envWorkarounds[i]) + } + } slog.Info("starting runner", "cmd", s.cmd) slog.Debug("subprocess", "", filteredEnv(s.cmd.Env)) @@ -416,7 +432,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a } }() - if newEstimates { + if textProcessor != nil { return &ollamaServer{llmServer: s}, nil } else { return &llamaServer{llmServer: s, ggml: f}, nil @@ -480,7 +496,7 @@ type LoadResponse struct { var ErrLoadRequiredFull = errors.New("unable to load full model on GPU") -func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error { +func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) { systemInfo := discover.GetSystemInfo() systemTotalMemory := systemInfo.System.TotalMemory systemFreeMemory := systemInfo.System.FreeMemory @@ -492,7 +508,8 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi if !requireFull { g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel) } else { - return ErrLoadRequiredFull + slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate) + return nil, ErrLoadRequiredFull } } @@ -501,13 +518,13 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi if len(gpus) > 1 || gpus[0].Library != "cpu" { switch { - case gpus[0].Library == "metal" && s.estimate.VRAMSize > systemInfo.System.TotalMemory: + case gpus[0].Library == "Metal" && s.estimate.VRAMSize > systemInfo.System.TotalMemory: // disable partial offloading when model is greater than total system memory as this // can lead to locking up the system s.options.NumGPU = 0 - case gpus[0].Library != "metal" && s.estimate.Layers == 0: + case gpus[0].Library != "Metal" && s.estimate.Layers == 0: // Don't bother loading into the GPU if no layers can fit - gpus = discover.GetCPUInfo() + gpus = discover.GpuInfoList{discover.GetCPUInfo()} case s.options.NumGPU < 0 && s.estimate.Layers > 0 && gpus[0].Library != "cpu": s.options.NumGPU = s.estimate.Layers } @@ -520,14 +537,10 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi available := systemInfo.System.FreeMemory + systemInfo.System.FreeSwap if systemMemoryRequired > available { slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.System.TotalMemory), "free", format.HumanBytes2(systemInfo.System.FreeMemory), "swap", format.HumanBytes2(systemInfo.System.FreeSwap)) - return fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available)) + return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available)) } } - if requireFull && len(gpus) == 1 && gpus[0].Library == "cpu" && s.estimate.TotalSize > gpus[0].FreeMemory { - return ErrLoadRequiredFull - } - slog.Info("offload", "", s.estimate) s.gpus = gpus @@ -539,7 +552,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi // mmap has issues with partial offloading on metal for _, g := range gpus { - if g.Library == "metal" && + if g.Library == "Metal" && uint64(s.options.NumGPU) > 0 && uint64(s.options.NumGPU) < s.ggml.KV().BlockCount()+1 { s.options.UseMMap = new(bool) @@ -550,7 +563,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi // Windows CUDA should not use mmap for best performance // Linux with a model larger than free space, mmap leads to thrashing // For CPU loads we want the memory to be allocated, not FS cache - if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && s.options.UseMMap == nil) || + if (runtime.GOOS == "windows" && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) || (runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) || (gpus[0].Library == "cpu" && s.options.UseMMap == nil) || (s.options.UseMMap != nil && !*s.options.UseMMap) { @@ -559,12 +572,12 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi } if err := s.waitUntilRunnerLaunched(ctx); err != nil { - return err + return nil, err } resp, err := s.initModel(ctx, s.loadRequest, LoadOperationCommit) if err != nil { - return err + return nil, err } // On the Ollama engine, we can print out a summary of the memory allocations. @@ -575,16 +588,16 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi if !resp.Success { slog.Warn("failed to allocate memory for model", "memory", resp.Memory) - return errors.New("failed to allocate memory for model") + return nil, errors.New("failed to allocate memory for model") } // The llama engine does its memory allocations together with model loading, so we // need to wait until it is done to ensure that we have accurate memory data before // loading the next model if s.textProcessor == nil { - return s.WaitUntilRunning(ctx) + return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx) } else { - return nil + return uniqueDeviceIDs(s.loadRequest.GPULayers), nil } } @@ -597,7 +610,7 @@ func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus discover.Gpu gpuLayers := make(ml.GPULayersList, len(gpus)) for i := range gpuLayers { - gpuLayers[i].ID = gpus[i].ID + gpuLayers[i].DeviceID = gpus[i].DeviceID } var sum float32 @@ -645,7 +658,9 @@ func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus discover.Gpu // // This process is repeated for higher levels of loading the model (fit, allocate, commit). The earlier levels are quicker, // allowing for faster iteration, but may return less information. -func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error { +// +// Returns the list of GPU IDs that were used in the final allocation on success +func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) { var success bool defer func() { if !success { @@ -666,8 +681,12 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ if !(len(gpus) == 1 && gpus[0].Library == "cpu") { for _, gpu := range gpus { - slog.Info("gpu memory", "id", gpu.ID, - "available", format.HumanBytes2(gpu.FreeMemory-envconfig.GpuOverhead()-gpu.MinimumMemory), + available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory + if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory { + available = 0 + } + slog.Info("gpu memory", "id", gpu.ID, "library", gpu.Library, + "available", format.HumanBytes2(available), "free", format.HumanBytes2(gpu.FreeMemory), "minimum", format.HumanBytes2(gpu.MinimumMemory), "overhead", format.HumanBytes2(envconfig.GpuOverhead())) @@ -679,11 +698,11 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, requireFull, backoff) if err != nil { - return err + return nil, err } if err := s.waitUntilRunnerLaunched(ctx); err != nil { - return err + return nil, err } nextOperation: @@ -693,7 +712,7 @@ nextOperation: s.loadRequest.GPULayers = gpuLayers resp, err := s.initModel(ctx, s.loadRequest, operation) if err != nil { - return err + return nil, err } resp.Memory.Log(slog.LevelDebug) @@ -705,7 +724,7 @@ nextOperation: for { newGPULayers, err := s.createLayout(systemInfo, gpus, s.mem, requireFull, backoff) if err != nil { - return err + return nil, err } slog.Debug("new layout created", "layers", newGPULayers) @@ -739,7 +758,7 @@ nextOperation: newGPULayers, err = s.createLayout(systemInfo, gpus, s.mem, requireFull, backoff) s.options.NumGPU = -1 if err != nil { - return err + return nil, err } slog.Debug("new layout created", "layers", newGPULayers) @@ -747,7 +766,7 @@ nextOperation: s.loadRequest.GPULayers = newGPULayers resp, err = s.initModel(ctx, s.loadRequest, operation) if err != nil { - return err + return nil, err } resp.Memory.Log(slog.LevelDebug) @@ -756,7 +775,7 @@ nextOperation: if resp.Success { verifyGPULayers, err := s.createLayout(systemInfo, gpus, &resp.Memory, requireFull, backoff) if err != nil { - return err + return nil, err } slog.Debug("verifying layout", "layers", verifyGPULayers) @@ -781,7 +800,7 @@ nextOperation: } if s.options.NumGPU >= 0 { - return fmt.Errorf("memory layout cannot be allocated with num_gpu = %v", s.options.NumGPU) + return nil, fmt.Errorf("memory layout cannot be allocated with num_gpu = %v", s.options.NumGPU) } // Memory allocation failed even though we created a layout that we thought should @@ -791,7 +810,7 @@ nextOperation: // space. if backoff > 1 { slog.Warn("memory layout cannot be allocated", "memory", resp.Memory) - return errors.New("memory layout cannot be allocated") + return nil, errors.New("memory layout cannot be allocated") } else if backoff == 0 { backoff = 0.01 } else { @@ -806,7 +825,7 @@ nextOperation: s.loadRequest.GPULayers = gpuLayers resp, err := s.initModel(ctx, s.loadRequest, LoadOperationCommit) if err != nil { - return err + return nil, err } success = resp.Success @@ -814,10 +833,27 @@ nextOperation: if !success { slog.Warn("failed to commit memory for model", "memory", resp.Memory) - return errors.New("failed to commit memory for model") + return nil, errors.New("failed to commit memory for model") } - return nil + return uniqueDeviceIDs(gpuLayers), nil +} + +func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID { + devices := []ml.DeviceID{} + for _, layer := range gpuLayers { + new := true + for _, ID := range devices { + if layer.DeviceID == ID { + new = false + break + } + } + if new { + devices = append(devices, layer.DeviceID) + } + } + return devices } // createLayout uses the current best view of memory requirements and creates a layout of model layers on GPUs. @@ -836,20 +872,20 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d if memory == nil { memory = &ml.BackendMemory{CPU: ml.DeviceMemory{ - Weights: make([]ml.Memory, s.totalLayers), - Cache: make([]ml.Memory, s.totalLayers), + Weights: make([]uint64, s.totalLayers), + Cache: make([]uint64, s.totalLayers), }} } layers := make([]uint64, len(memory.CPU.Weights)) for i := range layers { for j := range memory.GPUs { - layers[i] += memory.GPUs[j].Weights[i].Size - layers[i] += memory.GPUs[j].Cache[i].Size + layers[i] += memory.GPUs[j].Weights[i] + layers[i] += memory.GPUs[j].Cache[i] } - layers[i] += memory.CPU.Weights[i].Size - layers[i] += memory.CPU.Cache[i].Size - slog.Log(context.TODO(), logutil.LevelTrace, "layer to assign", "layer", i, "size", format.HumanBytes2(layers[i])) + layers[i] += memory.CPU.Weights[i] + layers[i] += memory.CPU.Cache[i] + logutil.Trace("layer to assign", "layer", i, "size", format.HumanBytes2(layers[i])) } gpuLayers := ml.GPULayersList{} @@ -862,23 +898,23 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d for i := range gl { found := false for j := range memory.GPUs { - if gl[i].ID == memory.GPUs[j].ID { - if memory.GPUs[j].Graph.Size != 0 { + if gl[i].DeviceID == memory.GPUs[j].DeviceID { + if memory.GPUs[j].Graph != 0 { lastUsedGPU = i } - reserved := uint64(float32(gl[i].FreeMemory)*backoff) + gl[i].MinimumMemory + envconfig.GpuOverhead() + memory.GPUs[j].Graph.Size + reserved := uint64(float32(gl[i].FreeMemory)*backoff) + gl[i].MinimumMemory + envconfig.GpuOverhead() + memory.GPUs[j].Graph if gl[i].FreeMemory > reserved { gl[i].FreeMemory -= reserved } else { gl[i].FreeMemory = 0 } - slog.Debug("available gpu", "id", gl[i].ID, + slog.Debug("available gpu", "id", gl[i].ID, "library", gl[i].Library, "available layer vram", format.HumanBytes2(gl[i].FreeMemory), "backoff", fmt.Sprintf("%.2f", backoff), "minimum", format.HumanBytes2(gl[i].MinimumMemory), "overhead", format.HumanBytes2(envconfig.GpuOverhead()), - "graph", format.HumanBytes2(memory.GPUs[j].Graph.Size)) + "graph", format.HumanBytes2(memory.GPUs[j].Graph)) found = true break @@ -897,12 +933,12 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d } // These sizes will only increase as we go through additional iterations and get additional information. - cpuSize := memory.InputWeights.Size + memory.CPU.Graph.Size + cpuSize := memory.InputWeights + memory.CPU.Graph var vramSize uint64 for _, gl := range gpuLayers { for _, gpu := range memory.GPUs { - if gl.ID == gpu.ID { - vramSize += gpu.Graph.Size + if gl.DeviceID == gpu.DeviceID { + vramSize += gpu.Graph break } } @@ -1022,7 +1058,7 @@ func findBestFit(layers []uint64, gpus discover.GpuInfoList, requestedLayers int // greedyFit assigns layers incrementally to GPUs, spilling over as each runs out of free space func greedyFit(layers []uint64, gpus discover.GpuInfoList, capacity float32, requestedLayers int) (gpuLayers ml.GPULayersList) { device := len(gpus) - 1 - gpuLayers = ml.GPULayersList{{ID: gpus[device].ID}} + gpuLayers = ml.GPULayersList{{DeviceID: gpus[device].DeviceID}} freeSpace := uint64(float32(gpus[device].FreeMemory) * capacity) for i := len(layers) - 1; i >= 0; i-- { if requestedLayers >= 0 && len(layers)-1-i >= requestedLayers { @@ -1040,7 +1076,7 @@ func greedyFit(layers []uint64, gpus discover.GpuInfoList, capacity float32, req if device < 0 { return gpuLayers } - gpuLayers = append(ml.GPULayersList{{ID: gpus[device].ID}}, gpuLayers...) + gpuLayers = append(ml.GPULayersList{{DeviceID: gpus[device].DeviceID}}, gpuLayers...) freeSpace = uint64(float32(gpus[device].FreeMemory) * capacity) } } @@ -1295,6 +1331,17 @@ func (s *llmServer) Pid() int { return -1 } +func (s *llmServer) GetPort() int { + return s.port +} + +func (s *llmServer) HasExited() bool { + if s.cmd != nil && s.cmd.ProcessState != nil && s.cmd.ProcessState.ExitCode() >= 0 { + return true + } + return false +} + var grammarJSON = ` root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws @@ -1369,7 +1416,7 @@ type CompletionResponse struct { func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { slog.Debug("completion request", "images", len(req.Images), "prompt", len(req.Prompt), "format", string(req.Format)) - slog.Log(ctx, logutil.LevelTrace, "completion request", "prompt", req.Prompt) + logutil.Trace("completion request", "prompt", req.Prompt) if len(req.Format) > 0 { switch string(req.Format) { @@ -1535,7 +1582,7 @@ type EmbeddingResponse struct { } func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) { - slog.Log(ctx, logutil.LevelTrace, "embedding request", "input", input) + logutil.Trace("embedding request", "input", input) if err := s.sem.Acquire(ctx, 1); err != nil { if errors.Is(err, context.Canceled) { @@ -1687,9 +1734,9 @@ func (s *llamaServer) TotalSize() uint64 { return s.estimate.TotalSize } -func (s *llamaServer) VRAMByGPU(gpuID string) uint64 { +func (s *llamaServer) VRAMByGPU(id ml.DeviceID) uint64 { for i, gpu := range s.gpus { - if gpu.ID == gpuID { + if gpu.DeviceID == id { if i < len(s.estimate.GPUSizes) { return s.estimate.GPUSizes[i] } @@ -1698,6 +1745,11 @@ func (s *llamaServer) VRAMByGPU(gpuID string) uint64 { return 0 } +func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { + slog.Debug("llamarunner free vram reporting not supported") + return nil +} + func (s *ollamaServer) VRAMSize() uint64 { if s.mem == nil { return 0 @@ -1706,21 +1758,21 @@ func (s *ollamaServer) VRAMSize() uint64 { var mem uint64 for _, g := range s.mem.GPUs { - mem += g.Allocated() + mem += g.Size() } // Some elements are always on CPU. However, if we have allocated all layers // on the GPU then include the CPU components as well, to represent complete offloading. noCPULayers := true for i := range s.mem.CPU.Weights { - if s.mem.CPU.Weights[i].Size != 0 || s.mem.CPU.Cache[i].Size != 0 { + if s.mem.CPU.Weights[i] != 0 || s.mem.CPU.Cache[i] != 0 { noCPULayers = false break } } if noCPULayers { - mem += s.mem.InputWeights.Size - mem += s.mem.CPU.Graph.Size + mem += s.mem.InputWeights + mem += s.mem.CPU.Graph } return mem @@ -1731,25 +1783,37 @@ func (s *ollamaServer) TotalSize() uint64 { return 0 } - mem := s.mem.InputWeights.Size - mem += s.mem.CPU.Allocated() + mem := s.mem.InputWeights + mem += s.mem.CPU.Size() for _, g := range s.mem.GPUs { - mem += g.Allocated() + mem += g.Size() } return mem } -func (s *ollamaServer) VRAMByGPU(gpuID string) uint64 { +func (s *ollamaServer) VRAMByGPU(id ml.DeviceID) uint64 { if s.mem == nil { return 0 } for _, g := range s.mem.GPUs { - if g.ID == gpuID { - return g.Allocated() + if g.DeviceID == id { + return g.Size() } } return 0 } + +func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { + devices, err := discover.GetDevicesFromRunner(ctx, s) + if err != nil { + if s.cmd != nil && s.cmd.ProcessState == nil { + // Still running but hit an error, log + slog.Debug("failure refreshing GPU information", "error", err) + } + // else no longer running so suppress logging as a failure is expected + } + return devices +} diff --git a/llm/server_test.go b/llm/server_test.go index 4eed82bce..f1e67c34e 100644 --- a/llm/server_test.go +++ b/llm/server_test.go @@ -16,8 +16,8 @@ import ( func TestLLMServerFitGPU(t *testing.T) { type gpu struct { - library string - free int + id ml.DeviceID + free int } tests := []struct { @@ -37,91 +37,91 @@ func TestLLMServerFitGPU(t *testing.T) { }, { name: "Full single GPU", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{0, 1, 2}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}}, }, { name: "Partial single GPU", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{1, 2}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}}, }, { name: "Single GPU with numGPU 1", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: 1, - expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{1}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}}, }, { name: "Single GPU with numGPU 0", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: 0, expected: ml.GPULayersList{}, }, { name: "Single GPU with numGPU 999", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte}, numGPU: 999, - expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{0, 1, 2, 3}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2, 3}}}, }, { name: "Multi GPU fits on one", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1, 2}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1, 2}}}, }, { name: "Multi GPU split", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0}}, {ID: "gpu0", Layers: []int{1, 2}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}}, }, { name: "Multi GPU partial", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{1}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}}, }, { name: "Multi GPU numGPU 1", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: 1, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{1}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}}, }, { name: "Multi GPU numGPU 2", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: 2, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0}}, {ID: "gpu0", Layers: []int{1}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}}, }, { name: "Multi GPU numGPU 999", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte}, numGPU: 999, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1}}, {ID: "gpu0", Layers: []int{2}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}}, }, { name: "Multi GPU different libraries", - gpus: []gpu{{library: "cuda", free: 128 * format.MebiByte}, {library: "rocm", free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{128 * format.MebiByte, 128 * format.MebiByte, 50 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1", Library: "ROCm"}, Layers: []int{0, 1}}}, }, { name: "requireFull", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte}, numGPU: -1, requireFull: true, @@ -138,8 +138,7 @@ func TestLLMServerFitGPU(t *testing.T) { gpus := make(discover.GpuInfoList, len(tt.gpus)) for i := range tt.gpus { - gpus[i].ID = fmt.Sprintf("gpu%d", i) - gpus[i].Library = tt.gpus[i].library + gpus[i].DeviceID = tt.gpus[i].id gpus[i].FreeMemory = uint64(tt.gpus[i].free) } @@ -155,18 +154,18 @@ func TestLLMServerFitGPU(t *testing.T) { } s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{ - Weights: make([]ml.Memory, s.totalLayers), - Cache: make([]ml.Memory, s.totalLayers), + Weights: make([]uint64, s.totalLayers), + Cache: make([]uint64, s.totalLayers), }, GPUs: make([]ml.DeviceMemory, len(gpus))} for i := range tt.layers { - s.mem.CPU.Weights[i].Size = uint64(tt.layers[i]) + s.mem.CPU.Weights[i] = uint64(tt.layers[i]) } for i := range s.mem.GPUs { - s.mem.GPUs[i].ID = fmt.Sprintf("gpu%d", i) - s.mem.GPUs[i].Weights = make([]ml.Memory, s.totalLayers) - s.mem.GPUs[i].Cache = make([]ml.Memory, s.totalLayers) + s.mem.GPUs[i].DeviceID = gpus[i].DeviceID + s.mem.GPUs[i].Weights = make([]uint64, s.totalLayers) + s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers) } gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, tt.requireFull, 0) diff --git a/logutil/logutil.go b/logutil/logutil.go index 406caf540..00daf6a6e 100644 --- a/logutil/logutil.go +++ b/logutil/logutil.go @@ -1,9 +1,12 @@ package logutil import ( + "context" "io" "log/slog" "path/filepath" + "runtime" + "time" ) const LevelTrace slog.Level = -8 @@ -27,3 +30,19 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger { }, })) } + +type key string + +func Trace(msg string, args ...any) { + TraceContext(context.WithValue(context.TODO(), key("skip"), 1), msg, args...) +} + +func TraceContext(ctx context.Context, msg string, args ...any) { + if logger := slog.Default(); logger.Enabled(ctx, LevelTrace) { + skip, _ := ctx.Value(key("skip")).(int) + pc, _, _, _ := runtime.Caller(1 + skip) + record := slog.NewRecord(time.Now(), LevelTrace, msg, pc) + record.Add(args...) + logger.Handler().Handle(ctx, record) + } +} diff --git a/ml/backend.go b/ml/backend.go index 705724821..351942d54 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -5,14 +5,11 @@ import ( "context" "encoding/binary" "fmt" - "hash/maphash" - "log/slog" "math" "slices" "strconv" "strings" - "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs" ) @@ -29,6 +26,9 @@ type Backend interface { Get(name string) Tensor NewContext() Context NewContextSize(size int) Context + + // Enumerate the devices available for inference via this backend + BackendDevices() []DeviceInfo } // BackendCacheConfig should be implemented by backends that need special output @@ -60,77 +60,6 @@ type CacheConfig struct { MaskBatchPadding int } -// GPULayers is a set of layers to be allocated on a single GPU -type GPULayers struct { - // ID is the identifier of the GPU, as reported in DeviceMemory - ID string - - // Layers is a set of layer indicies to load - Layers []int -} - -func (g GPULayers) String() string { - if len(g.Layers) == 0 { - return "" - } - - slices.Sort(g.Layers) - - contiguous := true - base := g.Layers[0] - for i := range g.Layers { - if g.Layers[i] != base+i { - contiguous = false - break - } - } - - if contiguous { - return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1]) - } else { - return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers) - } -} - -// GPULayersList is a set of layer allocations across multiple GPUs -type GPULayersList []GPULayers - -func (l GPULayersList) String() string { - if l.Sum() > 0 { - return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l)) - } else { - return fmt.Sprintf("%v", []GPULayers(l)) - } -} - -// Sum is the total number of layers assigned across all GPUs -func (l GPULayersList) Sum() int { - var sum int - - for _, g := range l { - sum += len(g.Layers) - } - - return sum -} - -var h maphash.Hash - -// Hash is an identifier of this layer assignment -func (l GPULayersList) Hash() uint64 { - h.Reset() - for _, g := range l { - if len(g.Layers) > 0 { - h.WriteString(g.ID) - for _, l := range g.Layers { - binary.Write(&h, binary.NativeEndian, int64(l)) - } - } - } - - return h.Sum64() -} - // BackendParams controls how the backend loads and executes models type BackendParams struct { // AllocMemory causes the backend to allocate memory for the model. If @@ -148,201 +77,6 @@ type BackendParams struct { FlashAttention bool } -// ErrNoMem is returned when panicing due to insufficient memory. It includes -// the attempted memory allocation. -type ErrNoMem struct { - BackendMemory -} - -func (e ErrNoMem) Error() string { - return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory) -} - -type AllocationStatus int - -const ( - // Unallocated memory - have not yet attempted to allocate - Unallocated AllocationStatus = iota - - // Failed memory - tried to allocate the memory and did not succeed - Failed - - // Allocated memory = tried and succeeded to allocate memory - Allocated -) - -// Memory is the size of an allocation and whether it was successful. -type Memory struct { - Size uint64 - Status AllocationStatus -} - -func (m Memory) String() string { - s := fmt.Sprint(m.Size) - - switch m.Status { - case Unallocated: - s += "U" - case Failed: - s += "F" - case Allocated: - s += "A" - } - - return s -} - -// DeviceMemory provides a breakdown of the memory needed -// per device, such as a CPU or GPU. -type DeviceMemory struct { - // Name is the name of the device as labeled by the backend. It - // may not be persistent across instances of the runner. - Name string - - // ID is an identifier for the device for matching with system - // management libraries. - ID string - - // Weights is the per-layer memory needed for the model weights. - Weights []Memory - - // Cache is the per-layer memory needed for the KV cache. - Cache []Memory - - // Graph is the size of the compute graph. It is not per-layer. - Graph Memory -} - -// Allocated returns the total size of the memory that has been successfully -// allocated on this device -func (m DeviceMemory) Allocated() uint64 { - var mem uint64 - - for _, w := range m.Weights { - if w.Status == Allocated { - mem += w.Size - } - } - for _, c := range m.Cache { - if c.Status == Allocated { - mem += c.Size - } - } - if m.Graph.Status == Allocated { - mem += m.Graph.Size - } - - return mem -} - -func memoryPresent(mem []Memory) bool { - return slices.ContainsFunc(mem, func(m Memory) bool { return m.Size != 0 }) -} - -func (m DeviceMemory) LogValue() slog.Value { - var attrs []slog.Attr - if memoryPresent(m.Weights) { - attrs = append(attrs, slog.Any("Weights", m.Weights)) - } - - if memoryPresent(m.Cache) { - attrs = append(attrs, slog.Any("Cache", m.Cache)) - } - - if m.Graph.Size != 0 { - attrs = append(attrs, slog.Any("Graph", m.Graph)) - } - - if len(attrs) > 0 && m.ID != "" { - attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...) - } - - return slog.GroupValue(attrs...) -} - -// BackendMemory provides the amount of memory required to load the model -// per device based on the BackendParams. In some cases, not all required -// allocations will be known at this point. However, the size of the most recent -// allocation is guaranteed to be provided so that if it failed, the caller can -// accommodate that to make forward progress. -type BackendMemory struct { - // InputsWeights are always located on the CPU and cannot be moved - InputWeights Memory - - // CPU model components are located in system memory. This does not - // include unified memory allocated through the GPU. - CPU DeviceMemory - - // GPU model components are located on one or more GPUs. - GPUs []DeviceMemory -} - -func (m BackendMemory) LogValue() slog.Value { - var attrs []slog.Attr - if m.InputWeights.Size != 0 { - attrs = append(attrs, slog.Any("InputWeights", m.InputWeights)) - } - - attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU)) - for _, g := range m.GPUs { - attrs = append(attrs, slog.Any(g.Name, g)) - } - - return slog.GroupValue(attrs...) -} - -func sumMemory(mem []Memory) uint64 { - var sum uint64 - - for _, m := range mem { - sum += m.Size - } - - return sum -} - -// Log prints a high level summary of the memory (allocated or not) -func (m BackendMemory) Log(level slog.Level) { - var total uint64 - - for _, gpu := range m.GPUs { - if sum := sumMemory(gpu.Weights); sum > 0 { - slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum)) - total += sum - } - } - if sum := m.InputWeights.Size + sumMemory(m.CPU.Weights); sum > 0 { - slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) - total += sum - } - - for _, gpu := range m.GPUs { - if sum := sumMemory(gpu.Cache); sum > 0 { - slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum)) - total += sum - } - } - if sum := sumMemory(m.CPU.Cache); sum > 0 { - slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) - total += sum - } - - for _, gpu := range m.GPUs { - if sum := gpu.Graph.Size; sum > 0 { - slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum)) - total += sum - } - } - if sum := m.CPU.Graph.Size; sum > 0 { - slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) - total += sum - } - - if total > 0 { - slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total)) - } -} - var backends = make(map[string]func(string, BackendParams) (Backend, error)) func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) { @@ -372,6 +106,7 @@ type Context interface { Forward(...Tensor) Context Compute(...Tensor) + ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun // Reserve is analogous to Compute but rather than executing a // graph, simply preallocates memory. Typically called with a @@ -401,6 +136,8 @@ type Tensor interface { Bytes() []byte Floats() []float32 + SetValueFromIntSlice(s []int32) + Neg(ctx Context) Tensor Add(ctx Context, t2 Tensor) Tensor Sub(ctx Context, t2 Tensor) Tensor @@ -413,6 +150,7 @@ type Tensor interface { AddID(ctx Context, t2, ids Tensor) Tensor Softmax(ctx Context) Tensor + L2Norm(ctx Context, eps float32) Tensor LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor RMSNorm(ctx Context, weight Tensor, eps float32) Tensor Scale(ctx Context, s float64) Tensor @@ -426,12 +164,13 @@ type Tensor interface { Sin(ctx Context) Tensor Cos(ctx Context) Tensor Tanh(ctx Context) Tensor - GELU(ctx Context) Tensor - QuickGELU(ctx Context) Tensor - SILU(ctx Context) Tensor - RELU(ctx Context) Tensor + GELU(ctx Context, up ...Tensor) Tensor + SILU(ctx Context, up ...Tensor) Tensor + RELU(ctx Context, up ...Tensor) Tensor Sigmoid(ctx Context) Tensor - SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor + + // AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit] + SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor Reshape(ctx Context, shape ...int) Tensor View(ctx Context, offset int, shape ...int) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 13d898aad..dc71c8de4 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1,5 +1,7 @@ package ggml +// #cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm +// #cgo windows LDFLAGS: -lpthread // #cgo CPPFLAGS: -I${SRCDIR}/ggml/include // #include // #include @@ -82,6 +84,7 @@ type Backend struct { // to the name that is used by the model definition tensorLoadTargets map[string][]string + schedMu sync.Mutex // Only one Compute can run at a time sched C.ggml_backend_sched_t schedBackends []C.ggml_backend_t schedBufts []C.ggml_backend_buffer_type_t @@ -158,7 +161,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { C.GGML_BACKEND_DEVICE_TYPE_ACCEL: bt := C.ggml_backend_dev_buffer_type(d) cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, bt) - C.ggml_backend_buft_set_alloc(bt, C.bool(params.AllocMemory)) btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU } @@ -168,8 +170,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { var props C.struct_ggml_backend_dev_props C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props) requiredMemory.CPU.ID = C.GoString(props.id) - requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1) - requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1) + requiredMemory.CPU.Library = C.GoString(props.library) + requiredMemory.CPU.Weights = make([]uint64, blocks+1) + requiredMemory.CPU.Cache = make([]uint64, blocks+1) // create list of buffer types for each gpu var gpuDeviceBufferTypes []deviceBufferType @@ -180,15 +183,15 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { d: d, bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...), }) - C.ggml_backend_buft_set_alloc(bt, C.bool(params.AllocMemory)) btDeviceMemory[bt] = &requiredMemory.GPUs[i] requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d)) var props C.struct_ggml_backend_dev_props C.ggml_backend_dev_get_props(d, &props) requiredMemory.GPUs[i].ID = C.GoString(props.id) - requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1) - requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1) + requiredMemory.GPUs[i].Library = C.GoString(props.library) + requiredMemory.GPUs[i].Weights = make([]uint64, blocks+1) + requiredMemory.GPUs[i].Cache = make([]uint64, blocks+1) } // inputs always use cpu @@ -199,7 +202,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { for _, l := range p.Layers { if l == layer { for i := range requiredMemory.GPUs { - if requiredMemory.GPUs[i].ID == p.ID { + if requiredMemory.GPUs[i].DeviceID == p.DeviceID { return gpuDeviceBufferTypes[i] } } @@ -270,17 +273,13 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0]))) C.ggml_set_name(tt, cname) - slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt))) + logutil.Trace("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt))) size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt)) if layer == -1 { - // Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case - if params.AllocMemory { - requiredMemory.InputWeights.Status = ml.Allocated - } - requiredMemory.InputWeights.Size += uint64(size) + requiredMemory.InputWeights += uint64(size) } else { - btDeviceMemory[bt].Weights[layer].Size += uint64(size) + btDeviceMemory[bt].Weights[layer] += uint64(size) } //nolint:staticcheck // TODO: check if buffer type supports this tensor @@ -340,47 +339,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } } - // allocate buffers for each context - bbs := make(map[*C.struct_ggml_context]C.ggml_backend_buffer_t, len(ctxs)) - for bt, c := range ctxs { - if C.ggml_get_first_tensor(c) == nil { - continue - } - - b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt) - if params.AllocMemory { - for i := range btDeviceMemory[bt].Weights { - if btDeviceMemory[bt].Weights[i].Size != 0 { - if b != nil { - btDeviceMemory[bt].Weights[i].Status = ml.Allocated - } else { - btDeviceMemory[bt].Weights[i].Status = ml.Failed - } - } - } - } - - if b == nil { - for _, b := range bbs { - C.ggml_backend_buffer_free(b) - } - - for _, ctx := range ctxs { - C.ggml_free(ctx) - } - - panic(ml.ErrNoMem{BackendMemory: requiredMemory}) - } - - C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) - bbs[c] = b - } - - for bs := range maps.Values(bbs) { - slog.Log(context.TODO(), logutil.LevelTrace, "model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), - "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs)))) - } - // map tensor names to tensors for easy lookup later tensors := make(map[string]*C.struct_ggml_tensor) for _, c := range ctxs { @@ -418,6 +376,46 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } maxGraphNodes := max(8192, len(meta.Tensors().Items())*5) + + sched := C.ggml_backend_sched_new_ext( + (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])), + (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), + C.int(len(schedBackends)), + C.size_t(maxGraphNodes), + C._Bool(false), + C._Bool(false), + C._Bool(params.AllocMemory), + ) + + // allocate buffers for each context + bbs := make(map[*C.struct_ggml_context]C.ggml_backend_buffer_t, len(ctxs)) + for bt, c := range ctxs { + if C.ggml_get_first_tensor(c) == nil { + continue + } + + b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt) + if b == nil { + for _, b := range bbs { + C.ggml_backend_buffer_free(b) + } + + for _, ctx := range ctxs { + C.ggml_free(ctx) + } + + panic(ml.ErrNoMem{BackendMemory: requiredMemory}) + } + + C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) + bbs[c] = b + } + + for bs := range maps.Values(bbs) { + logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), + "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs)))) + } + return &Backend{ modelPath: modelPath, allocMemory: params.AllocMemory, @@ -425,18 +423,11 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { meta: meta, tensorLoadTargets: targets, tensors: tensors, - sched: C.ggml_backend_sched_new( - (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])), - (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), - C.int(len(schedBackends)), - C.size_t(maxGraphNodes), - C._Bool(false), - C._Bool(false), - ), - schedBackends: schedBackends, - schedBufts: schedBufts, - input: deviceBufferTypes[input.d], - output: output.d, + sched: sched, + schedBackends: schedBackends, + schedBufts: schedBufts, + input: deviceBufferTypes[input.d], + output: output.d, layers: func() map[int]layerDevice { m := make(map[int]layerDevice) for i, layer := range layers { @@ -535,6 +526,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { const BS = 17 // MXFP4 block size bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned var s uint64 + var tmp [16]byte for s < t.Size() { // Stop if either the parent context has been canceled or if any of the other tensors returned an error if err := ctx.Err(); err != nil { @@ -546,37 +538,13 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { return err } for j := range n / BS { - for i := 1; i < BS; i++ { - // swap nibbles - t_lo := bts[j*BS+i] & 0x0F - t_hi := bts[j*BS+i] & 0xF0 - bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4) - } - // transform aaaa...bbbb... to abababab... - oi := 0 - tmp := [16]byte{} for i := 1; i < 9; i++ { - blk_a0 := bts[j*BS+i] & 0xF0 - blk_a1 := bts[j*BS+i] << 4 - blk_b0 := bts[j*BS+i+8] >> 4 - blk_b1 := bts[j*BS+i+8] & 0x0F - // swap once more - out0 := blk_a0 | blk_b0 - out1 := blk_a1 | blk_b1 - out_h0 := out0 & 0xF0 - out_l0 := out0 & 0x0F - out_h1 := out1 & 0xF0 - out_l1 := out1 & 0x0F - out0 = (out_h0 >> 4) | (out_l0 << 4) - out1 = (out_h1 >> 4) | (out_l1 << 4) - tmp[oi] = out0 - oi++ - tmp[oi] = out1 - oi++ - } - for i := range tmp { - bts[j*BS+i+1] = tmp[i] + // transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc + a, b := bts[j*BS+i], bts[j*BS+i+8] + tmp[2*(i-1)] = (a & 0x0F) | (b << 4) + tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0) } + copy(bts[j*BS+1:j*BS+17], tmp[:]) } for _, tt := range tts { @@ -652,6 +620,18 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { }) } + // Cleanup any backend state from devices that we didn't end up using +nextDevice: + for _, d := range append(gpus, append(accels, cpus...)...) { + for _, backend := range b.schedBackends { + if d == C.ggml_backend_get_device(backend) { + continue nextDevice + } + } + + C.ggml_backend_dev_reset(d) + } + if err := g.Wait(); err != nil { return err } @@ -706,6 +686,52 @@ func (b *Backend) CacheConfig() ml.CacheConfig { } } +func (b *Backend) BackendDevices() []ml.DeviceInfo { + deviceInfos := []ml.DeviceInfo{} + for _, dev := range gpus { + // If we have a model loaded, and it's only loaded on a subset of the devices + // skip idle/unused devices to avoid initializing them and causing VRAM allocations + if b.allocMemory { + idleDev := true + for _, backend := range b.schedBackends { + if dev == C.ggml_backend_get_device(backend) { + idleDev = false + break + } + } + if idleDev { + slog.Debug("skipping unused backend device", "description", C.GoString(C.ggml_backend_dev_description(dev))) + continue + } + } + + info := ml.DeviceInfo{} + props := C.struct_ggml_backend_dev_props{} + C.ggml_backend_dev_get_props(dev, &props) + info.Name = C.GoString(props.name) + info.Description = C.GoString(props.description) + info.ID = C.GoString(props.id) + info.Library = C.GoString(props.library) + info.ComputeMajor = (int)(props.compute_major) + info.ComputeMinor = (int)(props.compute_minor) + info.DriverMajor = (int)(props.driver_major) + info.DriverMinor = (int)(props.driver_minor) + info.Integrated = props.integrated != 0 + if props.library != nil { + info.Library = C.GoString(props.library) + } + info.PCIID = fmt.Sprintf("%02x:%02x.%x", props.pci_bus_id, props.pci_device_id, props.pci_domain_id) + info.LibraryPath = ggml.LibPaths() + + C.ggml_backend_dev_memory(dev, &props.memory_free, &props.memory_total) + info.TotalMemory = (uint64)(props.memory_total) + info.FreeMemory = (uint64)(props.memory_free) + + deviceInfos = append(deviceInfos, info) + } + return deviceInfos +} + type Context struct { b *Backend @@ -769,6 +795,15 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { } func (c *Context) Compute(tensors ...ml.Tensor) { + c.ComputeWithNotify(nil, tensors...) +} + +func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) { + c.b.schedMu.Lock() + defer c.b.schedMu.Unlock() + if cb != nil { + go cb() + } if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS { panic(fmt.Errorf("error computing ggml graph: %v", status)) } @@ -796,24 +831,15 @@ func (c *Context) Reserve() { // Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations for _, bt := range c.b.schedBufts { - c.b.btDeviceMemory[bt].Graph = ml.Memory{} + c.b.btDeviceMemory[bt].Graph = 0 } for i := range c.b.schedBackends { - bufferStatus := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i]) + bufferSize := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i]) + c.b.btDeviceMemory[c.b.schedBufts[i]].Graph += uint64(bufferSize) - graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph - graph.Size += uint64(bufferStatus.size) - if c.b.allocMemory { - if bufferStatus.allocated && graph.Status != ml.Failed { - graph.Status = ml.Allocated - } else { - graph.Status = ml.Failed - } - } - - slog.Log(context.TODO(), logutil.LevelTrace, "compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), - "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size))) + logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), + "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferSize))) } if !reserved { @@ -863,16 +889,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { b := C.ggml_backend_buft_alloc_buffer(c.buft, size) if c.layer >= 0 { - cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer] - - cache.Size += uint64(size) - if c.b.allocMemory { - if b != nil { - cache.Status = ml.Allocated - } else { - cache.Status = ml.Failed - } - } + c.b.btDeviceMemory[c.buft].Cache[c.layer] += uint64(size) } if b == nil { @@ -1021,6 +1038,12 @@ func (t *Tensor) Floats() (data []float32) { return } +func (t *Tensor) SetValueFromIntSlice(s []int32) { + if len(s) > 0 { + C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t)) + } +} + func (t *Tensor) DType() ml.DType { switch t.t._type { case C.GGML_TYPE_F32: @@ -1200,6 +1223,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor { } } +func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)), + } +} + func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)) if w != nil { @@ -1419,35 +1449,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int } } -func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { +func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { + if len(t2) > 0 { + return &Tensor{ + b: t.b, + t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t), + } + } return &Tensor{ b: t.b, t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t), } } -func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor { - return &Tensor{ - b: t.b, - t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t), +func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { + if len(t2) > 0 { + return &Tensor{ + b: t.b, + t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t), + } } -} - -func (t *Tensor) SILU(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t), } } -func (t *Tensor) RELU(ctx ml.Context) ml.Tensor { +func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { + if len(t2) > 0 { + return &Tensor{ + b: t.b, + t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t), + } + } return &Tensor{ b: t.b, t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t), } } -func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor { +func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor { return &Tensor{ b: t.b, t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)), diff --git a/ml/backend/ggml/ggml/include/ggml-alloc.h b/ml/backend/ggml/ggml/include/ggml-alloc.h index 781b1e100..7ab3f0192 100644 --- a/ml/backend/ggml/ggml/include/ggml-alloc.h +++ b/ml/backend/ggml/ggml/include/ggml-alloc.h @@ -65,12 +65,7 @@ GGML_API bool ggml_gallocr_reserve_n( GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph); GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); - -struct ggml_allocr_buffer_status { - size_t size; - bool allocated; -}; -GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); +GGML_API size_t ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); // Utils // Create a buffer and allocate all the tensors in a ggml_context diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h index b602a7c78..38418c4cd 100644 --- a/ml/backend/ggml/ggml/include/ggml-backend.h +++ b/ml/backend/ggml/ggml/include/ggml-backend.h @@ -35,7 +35,6 @@ extern "C" { // GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft); - GGML_API void ggml_backend_buft_set_alloc (ggml_backend_buffer_type_t buft, bool alloc); GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size); GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); @@ -158,6 +157,15 @@ extern "C" { size_t memory_total; enum ggml_backend_dev_type type; struct ggml_backend_dev_caps caps; + int driver_major; + int driver_minor; + int compute_major; + int compute_minor; + int integrated; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; + const char *library; }; GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); @@ -167,6 +175,7 @@ extern "C" { GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props); GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device); GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params); + GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device); GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device); GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device); GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size); @@ -292,6 +301,7 @@ extern "C" { // Initialize a backend scheduler, backends with low index are given priority over backends with high index GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); + GGML_API ggml_backend_sched_t ggml_backend_sched_new_ext(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload, bool alloc_buffers); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); // Initialize backend buffers from a measure graph @@ -305,12 +315,7 @@ extern "C" { GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched); GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); - - struct ggml_backend_buffer_status { - size_t size; - bool allocated; - }; - GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + GGML_API size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt index 5158acd6a..3a428a22d 100644 --- a/ml/backend/ggml/ggml/src/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/CMakeLists.txt @@ -203,6 +203,8 @@ add_library(ggml-base ggml-threading.h ggml-quants.c ggml-quants.h + mem_hip.cpp + mem_nvml.cpp gguf.cpp) target_include_directories(ggml-base PRIVATE .) diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c index 41c8c4a2f..b58bd671d 100644 --- a/ml/backend/ggml/ggml/src/ggml-alloc.c +++ b/ml/backend/ggml/ggml/src/ggml-alloc.c @@ -932,7 +932,7 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); } -struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { +size_t ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers); for (int i = 0; i < buffer_id; i++) { @@ -941,13 +941,11 @@ struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gal // (See above.) However, we need a different check because multiple buffers might be NULL in our // case and we still want to know the attempted size. - struct ggml_allocr_buffer_status status = {0, true}; - return status; + return 0; } } - struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL}; - return status; + return galloc->buffer_sizes[buffer_id]; } // utils diff --git a/ml/backend/ggml/ggml/src/ggml-backend-impl.h b/ml/backend/ggml/ggml/src/ggml-backend-impl.h index 81749a5a3..272571867 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-backend-impl.h @@ -26,6 +26,10 @@ extern "C" { size_t (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false) bool (*is_host) (ggml_backend_buffer_type_t buft); + + // (optional) returns a dummy buffer that is equivalent to one created by alloc_buffer but without actually being backed + // by memory + ggml_backend_buffer_t (*noalloc_buffer)(ggml_backend_buffer_type_t buft, size_t size); }; struct ggml_backend_buffer_type { @@ -116,6 +120,16 @@ extern "C" { void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); // wait for an event on on a different stream void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + + // (optional) reserves intermediate buffers needed for the compution + // if alloc is true, memory is actually allocated, otherwise the required amount is just returned by buffer_size + enum ggml_status (*graph_reserve) (ggml_backend_t backend, struct ggml_cgraph * cgraph, bool alloc); + + // (optional) returns the memory needed after calling graph_reserve + size_t (*buffer_size) (ggml_backend_t backend); + + // (optional) frees memory from intermediate buffers that was allocated either by graph_compute or graph_reserve + void (*reset) (ggml_backend_t backend); }; struct ggml_backend { @@ -178,6 +192,10 @@ extern "C" { ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev); void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event); void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event); + + // (optional) reset device, clearing existing allocations and context + // the caller must ensure that there are no outstanding buffers, as these will become invalid + void (*reset)(ggml_backend_dev_t dev); }; struct ggml_backend_device { diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp index 05a842ed5..5f99948d5 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp @@ -35,10 +35,6 @@ const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { return buft->iface.get_name(buft); } -void ggml_backend_buft_set_alloc(ggml_backend_buffer_type_t buft, bool alloc) { - buft->no_alloc = !alloc; -} - ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { if (size == 0) { // return a dummy buffer for zero-sized allocations @@ -46,7 +42,14 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t } if (buft->no_alloc) { - ggml_backend_buffer_t buf = ggml_backend_buffer_init(buft, {}, NULL, size); + ggml_backend_buffer_t buf; + + if (buft->iface.noalloc_buffer != NULL) { + buf = buft->iface.noalloc_buffer(buft, size); + } else { + buf = ggml_backend_buffer_init(buft, {}, NULL, size); + } + buf->no_alloc = true; return buf; } @@ -477,6 +480,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par return device->iface.init_backend(device, params); } +void ggml_backend_dev_reset(ggml_backend_dev_t device) { + if (device->iface.reset == NULL) { + return; + } + + device->iface.reset(device); +} + ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) { return device->iface.get_buffer_type(device); } @@ -680,6 +691,12 @@ struct ggml_backend_sched { bool op_offload; int debug; + + // allocate buffers on attached ggml_backend_buffer_type_t's and during reservation + // if false, dummy buffers are used for faster memory sizing calculations + // the scheduler needs to be recreated with allocated buffers before it can be used + // for computation + bool alloc_buffers; }; #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) @@ -1466,6 +1483,17 @@ ggml_backend_sched_t ggml_backend_sched_new( size_t graph_size, bool parallel, bool op_offload) { + return ggml_backend_sched_new_ext(backends, bufts, n_backends, graph_size, parallel, op_offload, true); + } + +ggml_backend_sched_t ggml_backend_sched_new_ext( + ggml_backend_t * backends, + ggml_backend_buffer_type_t * bufts, + int n_backends, + size_t graph_size, + bool parallel, + bool op_offload, + bool alloc_buffers) { GGML_ASSERT(n_backends > 0); GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); @@ -1507,10 +1535,13 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->events[b][c] = ggml_backend_event_new(backends[b]->device); } } + + sched->bufts[b]->no_alloc = !alloc_buffers; } sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); sched->op_offload = op_offload; + sched->alloc_buffers = alloc_buffers; ggml_backend_sched_reset(sched); @@ -1525,6 +1556,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { for (int c = 0; c < sched->n_copies; c++) { ggml_backend_event_free(sched->events[b][c]); } + + if (sched->backends[b]->iface.reset != NULL) { + sched->backends[b]->iface.reset(sched->backends[b]); + } } ggml_gallocr_free(sched->galloc); ggml_free(sched->ctx); @@ -1564,6 +1599,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * return false; } + if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { + return false; + } + + struct ggml_backend_sched_split * splits = sched->splits; + for (int i = 0; i < sched->n_splits; i++) { + struct ggml_backend_sched_split * split = &splits[i]; + int split_backend_id = split->backend_id; + ggml_backend_t split_backend = sched->backends[split_backend_id]; + + if (split_backend->iface.graph_reserve != NULL) { + enum ggml_status ec = split_backend->iface.graph_reserve(split_backend, &split->graph, sched->alloc_buffers); + if (ec != GGML_STATUS_SUCCESS) { + return false; + } + } + } + ggml_backend_sched_reset(sched); return true; @@ -1648,14 +1701,17 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); } -struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { +size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); - struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); - struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated}; + size_t size = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); - return status; + if (backend->iface.buffer_size != NULL) { + size += backend->iface.buffer_size(backend); + } + + return size; } void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh index 2e5d48797..b915ee1b8 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh @@ -35,6 +35,31 @@ #include "vendors/cuda.h" #endif // defined(GGML_USE_HIP) +extern bool reserving_graph; + +// If we are reserving the graph, pointers might be invalid and will fail if cudaMemcpyAsync tries to validate them. +// However, since we don't actually expect a result, we don't need to actually do the memcpy. +static cudaError_t cudaMemcpyAsyncReserve ( void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream = 0 ) { + if (!reserving_graph) { + return cudaMemcpyAsync(dst, src, count, kind, stream); + } else { + return cudaSuccess; + } +} + +static cudaError_t cudaMemcpy2DAsyncReserve ( void* dst, size_t dpitch, const void* src, size_t spitch, size_t width, size_t height, cudaMemcpyKind kind, cudaStream_t stream = 0 ) { + if (!reserving_graph) { + return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, kind, stream); + } else { + return cudaSuccess; + } +} + +#undef cudaMemcpyAsync +#define cudaMemcpyAsync cudaMemcpyAsyncReserve +#undef cudaMemcpy2DAsync +#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve + #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) @@ -771,6 +796,9 @@ struct ggml_cuda_pool { virtual void * alloc(size_t size, size_t * actual_size) = 0; virtual void free(void * ptr, size_t size) = 0; + + virtual bool alloc_memory() = 0; + virtual size_t alloc_size() = 0; }; template @@ -914,11 +942,11 @@ struct ggml_backend_cuda_context { // pool std::unique_ptr pools[GGML_CUDA_MAX_DEVICES]; - static std::unique_ptr new_pool_for_device(int device); + static std::unique_ptr new_pool_for_device(int device, bool alloc); ggml_cuda_pool & pool(int device) { if (pools[device] == nullptr) { - pools[device] = new_pool_for_device(device); + pools[device] = new_pool_for_device(device, true); } return *pools[device]; } @@ -926,4 +954,20 @@ struct ggml_backend_cuda_context { ggml_cuda_pool & pool() { return pool(device); } + + void pool_set_alloc(bool alloc) { + GGML_ASSERT(pools[device] == nullptr || pools[device]->alloc_memory() == alloc); + + if (pools[device] == nullptr) { + pools[device] = new_pool_for_device(device, alloc); + } + } + + size_t pool_get_alloc_size() { + if (pools[device] == nullptr) { + return 0; + } + + return pools[device]->alloc_size(); + } }; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index c7f9dc3a5..cc52361fe 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -103,6 +103,11 @@ int ggml_cuda_get_device() { return id; } +void ggml_cuda_reset_device(int device) { + ggml_cuda_set_device(device); + CUDA_CHECK(cudaDeviceReset()); +} + static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); cudaError_t err; @@ -274,6 +279,16 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; +#if defined(GGML_USE_HIP) + if (std::getenv("GGML_CUDA_INIT") != NULL) { + GGML_LOG_INFO("%s: initializing rocBLAS on device %d\n", __func__, id); + CUDA_CHECK(cudaSetDevice(id)); + // rocblas_initialize will SIGABRT if the GPU isn't supported + rocblas_initialize(); + GGML_LOG_INFO("%s: rocBLAS initialized on device %d\n", __func__, id); + } +#endif + #if defined(GGML_USE_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); @@ -327,9 +342,15 @@ static ggml_cuda_device_info ggml_cuda_init() { #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; +#ifdef __CUDA_ARCH_LIST__ + if (std::getenv("GGML_CUDA_INIT") != NULL) { + GGML_ASSERT(ggml_cuda_has_arch(info.devices[id].cc) && "ggml was not compiled with support for this arch"); + } +#endif // defined(__CUDA_ARCH_LIST__) GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", ggml_cuda_parse_uuid(prop, id).c_str()); + #endif // defined(GGML_USE_HIP) } @@ -350,6 +371,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { // #define DEBUG_CUDA_MALLOC +#define CUDA_ALIGNMENT 128 + // buffer pool for cuda (legacy) struct ggml_cuda_pool_leg : public ggml_cuda_pool { static const int MAX_BUFFERS = 256; @@ -362,9 +385,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {}; size_t pool_size = 0; + bool allocate = true; + size_t last_alloc = 0; - explicit ggml_cuda_pool_leg(int device) : - device(device) { + explicit ggml_cuda_pool_leg(int device, bool alloc) : + device(device), + allocate(alloc) { } ~ggml_cuda_pool_leg() { @@ -372,7 +398,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { - CUDA_CHECK(cudaFree(b.ptr)); + if (allocate) { + CUDA_CHECK(cudaFree(b.ptr)); + } pool_size -= b.size; } } @@ -420,8 +448,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { void * ptr; size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); - ggml_cuda_set_device(device); - CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); + if (allocate) { + ggml_cuda_set_device(device); + if (ggml_cuda_device_malloc(&ptr, look_ahead_size, device) != cudaSuccess) { + last_alloc = look_ahead_size; + throw std::bad_alloc(); + } + } else { + ptr = (void *)CUDA_ALIGNMENT; + } *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC @@ -441,10 +476,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } } GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n"); - ggml_cuda_set_device(device); - CUDA_CHECK(cudaFree(ptr)); + if (allocate) { + ggml_cuda_set_device(device); + CUDA_CHECK(cudaFree(ptr)); + } pool_size -= size; } + + bool alloc_memory() override { + return allocate; + } + + size_t alloc_size() override { + return pool_size + last_alloc; + } }; // pool with virtual memory @@ -456,18 +501,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { CUdeviceptr pool_addr = 0; size_t pool_used = 0; size_t pool_size = 0; + bool allocate = true; + size_t last_alloc = 0; size_t granularity; #if defined(GGML_USE_HIP) std::vector> mappings; #endif - explicit ggml_cuda_pool_vmm(int device) : + explicit ggml_cuda_pool_vmm(int device, bool alloc) : device(device), - granularity(ggml_cuda_info().devices[device].vmm_granularity) { + granularity(ggml_cuda_info().devices[device].vmm_granularity), + allocate(alloc) { + if (!allocate) { + pool_addr = (CUdeviceptr)CUDA_ALIGNMENT; + } } ~ggml_cuda_pool_vmm() { - if (pool_addr != 0) { + if (pool_addr != 0 && allocate) { #if defined(GGML_USE_HIP) // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 for (std::pair & mapping : mappings) { @@ -494,36 +545,50 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE); - // allocate more physical memory - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = device; - CUmemGenericAllocationHandle handle; - CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0)); + if (allocate) { + // allocate more physical memory + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + CUmemGenericAllocationHandle handle; + if (cuMemCreate(&handle, reserve_size, &prop, 0) != CUDA_SUCCESS) { + last_alloc = reserve_size; + throw std::bad_alloc(); + } - // reserve virtual address space (if not already reserved) - if (pool_addr == 0) { - CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); + // reserve virtual address space (if not already reserved) + if (pool_addr == 0) { + CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); + } + + // map at the end of the pool + CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); + if (cuMemMap(start_ptr, reserve_size, 0, handle, 0) != CUDA_SUCCESS) { + last_alloc = reserve_size; + CU_CHECK(cuMemRelease(handle)); + throw std::bad_alloc(); + } + + // the memory allocation handle is no longer needed after mapping + CU_CHECK(cuMemRelease(handle)); + + // set access + CUmemAccessDesc access = {}; + access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access.location.id = device; + access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + if (cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1) != CUDA_SUCCESS) { + CU_CHECK(cuMemUnmap(start_ptr, reserve_size)); + last_alloc = reserve_size; + throw std::bad_alloc(); + } + + #if defined(GGML_USE_HIP) + mappings.push_back({start_ptr, reserve_size}); + #endif } - // map at the end of the pool - CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); - CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0)); -#if defined(GGML_USE_HIP) - mappings.push_back({start_ptr, reserve_size}); -#endif - - // the memory allocation handle is no longer needed after mapping - CU_CHECK(cuMemRelease(handle)); - - // set access - CUmemAccessDesc access = {}; - access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - access.location.id = device; - access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); - // add to the pool pool_size += reserve_size; @@ -555,16 +620,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { // all deallocations must be in reverse order of the allocations GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); } + + bool alloc_memory() override { + return allocate; + } + + size_t alloc_size() override { + return pool_size + last_alloc; + } }; #endif // defined(GGML_USE_VMM) -std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { +std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device, bool alloc) { #if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { - return std::unique_ptr(new ggml_cuda_pool_vmm(device)); + return std::unique_ptr(new ggml_cuda_pool_vmm(device, alloc)); } #endif // defined(GGML_USE_VMM) - return std::unique_ptr(new ggml_cuda_pool_leg(device)); + return std::unique_ptr(new ggml_cuda_pool_leg(device, alloc)); } // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error @@ -748,11 +821,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac } static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; + return CUDA_ALIGNMENT; GGML_UNUSED(buft); } +static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_noalloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; + + void * dev_ptr = (void *)ggml_backend_cuda_buffer_type_get_alignment(buft); + ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); + + return ggml_backend_buffer_init(buft, {}, ctx, size); +} + static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { size_t size = ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; @@ -776,6 +858,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, /* .is_host = */ NULL, + /* .noalloc_buffer = */ ggml_backend_cuda_buffer_type_noalloc_buffer, }; ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { @@ -2936,6 +3019,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { + // flag used to determine whether it is an integrated_gpu const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; @@ -2951,6 +3035,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + // When reserving, we are forcing CUDA graphs but this operation is not graph-safe so we need to skip it + if (reserving_graph && node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { + continue; + } + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { @@ -3022,6 +3111,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + cuda_ctx->pool_set_alloc(true); ggml_cuda_set_device(cuda_ctx->device); @@ -3101,6 +3191,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, return GGML_STATUS_SUCCESS; } +// This is used to skip operations that are not graph safe during the reservation process. +bool reserving_graph = false; + +static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, ggml_cgraph * cgraph, bool alloc) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + cuda_ctx->pool_set_alloc(alloc); + + #ifdef USE_CUDA_GRAPH + if (cuda_ctx->cuda_graph == nullptr) { + cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); + } + #endif + + ggml_cuda_set_device(cuda_ctx->device); + + { + std::lock_guard lock(ggml_cuda_lock); + ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed); + } + + reserving_graph = true; + + // Create CuBLAS handles early to avoid synchronous allocations during graph capture. + cuda_ctx->cublas_handle(); + + CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); + + enum ggml_status result = GGML_STATUS_SUCCESS; + + try { + bool use_cuda_graph = false; + bool cuda_graph_update_required = false; + bool graph_evaluated_or_captured = false; + + evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + } catch (const std::exception &e) { + result = GGML_STATUS_FAILED; + } + + cudaGraph_t graph; + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph)); + CUDA_CHECK(cudaGraphDestroy(graph)); + + reserving_graph = false; + + { + std::lock_guard lock(ggml_cuda_lock); + if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) { + ggml_cuda_lock_cv.notify_all(); + } + } + + return result; +} + +static size_t ggml_backend_cuda_buffer_size(ggml_backend_t backend) { + ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; + return ctx->pool_get_alloc_size(); +} + +static void ggml_backend_cuda_reset(ggml_backend_t backend) { + ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; + ctx->pools[ctx->device] = NULL; +} + static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; @@ -3140,6 +3295,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, + /* .graph_reserve = */ ggml_backend_cuda_graph_reserve, + /* .buffer_size = */ ggml_backend_cuda_buffer_size, + /* .reset = */ ggml_backend_cuda_reset, }; static ggml_guid_t ggml_backend_cuda_guid() { @@ -3210,6 +3368,14 @@ struct ggml_backend_cuda_device_context { std::string name; std::string description; std::string id; + int major; + int minor; + int driver_major; + int driver_minor; + int integrated; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { @@ -3230,6 +3396,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); + +#if defined(GGML_USE_HIP) + if (ggml_hip_mgmt_init() == 0) { + int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_hip_mgmt_release(); + return; + } + ggml_hip_mgmt_release(); + } +#else + if (ggml_nvml_init() == 0) { + int status = ggml_nvml_get_device_memory(ctx->id.c_str(), free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_nvml_release(); + return; + } + ggml_nvml_release(); + } +#endif CUDA_CHECK(cudaMemGetInfo(free, total)); } @@ -3238,12 +3426,33 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend return GGML_BACKEND_DEVICE_TYPE_GPU; } +#define GGML_HIP_NAME "HIP" static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { props->name = ggml_backend_cuda_device_get_name(dev); props->description = ggml_backend_cuda_device_get_description(dev); props->id = ggml_backend_cuda_device_get_id(dev); props->type = ggml_backend_cuda_device_get_type(dev); - ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); + + // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device). + // If you need the memory data, call ggml_backend_dev_memory() explicitly. + props->memory_total = props->memory_free = 0; + + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; +#if defined(GGML_USE_HIP) + int cc = ggml_cuda_info().devices[ctx->device].cc - GGML_CUDA_CC_OFFSET_AMD; + props->compute_major = cc / 0x100; + props->compute_minor = cc - (props->compute_major * 0x100); +#else + props->compute_major = ctx->major; + props->compute_minor = ctx->minor; +#endif + props->driver_major = ctx->driver_major; + props->driver_minor = ctx->driver_minor; + props->integrated = ctx->integrated; + props->pci_bus_id = ctx->pci_bus_id; + props->pci_device_id = ctx->pci_device_id; + props->pci_domain_id = ctx->pci_domain_id; + props->library = GGML_CUDA_NAME; bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY @@ -3700,6 +3909,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); } +static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + ggml_cuda_reset_device(ctx->device); +} + static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .get_name = */ ggml_backend_cuda_device_get_name, /* .get_description = */ ggml_backend_cuda_device_get_description, @@ -3716,6 +3930,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .event_new = */ ggml_backend_cuda_device_event_new, /* .event_free = */ ggml_backend_cuda_device_event_free, /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, + /* .reset = */ ggml_backend_cuda_device_reset, }; // backend reg @@ -3829,18 +4044,26 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; + int driverVersion = 0; + CUDA_CHECK(cudaDriverGetVersion(&driverVersion)); for (int i = 0; i < ggml_cuda_info().device_count; i++) { ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; dev_ctx->device = i; dev_ctx->name = GGML_CUDA_NAME + std::to_string(i); - ggml_cuda_set_device(i); cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; dev_ctx->id = ggml_cuda_parse_uuid(prop, i); - + dev_ctx->major = prop.major; + dev_ctx->minor = prop.minor; + dev_ctx->driver_major = driverVersion / 1000; + dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10; + dev_ctx->integrated = prop.integrated; + dev_ctx->pci_bus_id = prop.pciBusID; + dev_ctx->pci_device_id = prop.pciDeviceID; + dev_ctx->pci_domain_id = prop.pciDomainID; ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_cuda_device_interface, /* .reg = */ ®, diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/scale.cu b/ml/backend/ggml/ggml/src/ggml-cuda/scale.cu index 2ee9e5889..0ddeff6a1 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/scale.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/scale.cu @@ -1,18 +1,19 @@ #include "scale.cuh" -static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +#define MAX_GRIDDIM_X 0x7FFFFFFF - if (i >= k) { - return; +static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; + int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; + + for (int64_t i = tid; i < nelements; i += stride) { + dst[i] = scale * x[i] + bias; } - - dst[i] = scale * x[i] + bias; } -static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, bias, k); +static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { + const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; + scale_f32<<>>(x, dst, scale, bias, nelements); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h index c31f31923..957a795f2 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h @@ -40,7 +40,9 @@ #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceReset hipDeviceReset #define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaDriverGetVersion hipDriverGetVersion #define cudaError_t hipError_t #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled diff --git a/ml/backend/ggml/ggml/src/ggml-impl.h b/ml/backend/ggml/ggml/src/ggml-impl.h index 19a7adb2d..b9b102a5e 100644 --- a/ml/backend/ggml/ggml/src/ggml-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-impl.h @@ -602,6 +602,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx return true; } +// Management libraries for fetching more accurate free VRAM data +GGML_API int ggml_nvml_init(); +GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total); +GGML_API void ggml_nvml_release(); +GGML_API int ggml_hip_mgmt_init(); +GGML_API int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total); +GGML_API void ggml_hip_mgmt_release(); + #ifdef __cplusplus } #endif diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m index e4c31268f..5451483de 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m @@ -6523,12 +6523,14 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen GGML_UNUSED(dev); } +#define GGML_METAL_NAME "Metal" static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { props->name = ggml_backend_metal_device_get_name(dev); props->description = ggml_backend_metal_device_get_description(dev); props->id = "0"; props->type = ggml_backend_metal_device_get_type(dev); ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->library = GGML_METAL_NAME; props->caps = (struct ggml_backend_dev_caps) { /* .async = */ false, /* .host_buffer = */ false, diff --git a/ml/backend/ggml/ggml/src/ggml.cpp b/ml/backend/ggml/ggml/src/ggml.cpp index 0d388d455..f5bcb446d 100644 --- a/ml/backend/ggml/ggml/src/ggml.cpp +++ b/ml/backend/ggml/ggml/src/ggml.cpp @@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{ return false; } const auto prev{std::get_terminate()}; - GGML_ASSERT(prev != ggml_uncaught_exception); - previous_terminate_handler = prev; + // GGML_ASSERT(prev != ggml_uncaught_exception); + if (prev != ggml_uncaught_exception) { + previous_terminate_handler = prev; + } else { + GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__); + } std::set_terminate(ggml_uncaught_exception); return true; }(); diff --git a/ml/backend/ggml/ggml/src/ggml.go b/ml/backend/ggml/ggml/src/ggml.go index 37347807d..7e215916e 100644 --- a/ml/backend/ggml/ggml/src/ggml.go +++ b/ml/backend/ggml/ggml/src/ggml.go @@ -75,9 +75,9 @@ var OnceLoad = sync.OnceFunc(func() { paths = value } - split := filepath.SplitList(paths) - visited := make(map[string]struct{}, len(split)) - for _, path := range split { + libPaths = filepath.SplitList(paths) + visited := make(map[string]struct{}, len(libPaths)) + for _, path := range libPaths { abspath, err := filepath.Abs(path) if err != nil { slog.Error("failed to get absolute path", "error", err) @@ -104,6 +104,12 @@ var OnceLoad = sync.OnceFunc(func() { slog.Info("system", "", system{}) }) +var libPaths []string + +func LibPaths() []string { + return libPaths +} + type system struct{} func (system) LogValue() slog.Value { diff --git a/ml/backend/ggml/ggml/src/mem_hip.cpp b/ml/backend/ggml/ggml/src/mem_hip.cpp new file mode 100644 index 000000000..8ef19b8cf --- /dev/null +++ b/ml/backend/ggml/ggml/src/mem_hip.cpp @@ -0,0 +1,449 @@ +#include "ggml.h" + +#ifdef _WIN32 +// AMD Device Library eXtra (ADLX) +// +// https://github.com/GPUOpen-LibrariesAndSDKs/ADLX +// +// This Windows-only library provides accurate VRAM reporting for AMD GPUs. +// The runtime DLL is installed with every AMD Driver on Windows, however +// the SDK isn't a part of the HIP SDK packaging. As such, we avoid including +// the headers from the SDK to simplify building from source. +// +// ADLX relies heavily on function pointer tables. +// Only the minimal set of types are defined below to facilitate +// finding the target AMD GPU(s) and querying their current VRAM usage +// Unused function parameters are commented out to avoid unnecessary type +// definitions. + +#include "ggml-impl.h" +#include +#include + +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include + +namespace fs = std::filesystem; + +#include +#include + +// Begin minimal ADLX definitions - derived from tag v1.0 (Dec 2022) +typedef uint64_t adlx_uint64; +typedef uint32_t adlx_uint32; +typedef int32_t adlx_int32; +typedef adlx_int32 adlx_int; +typedef adlx_uint32 adlx_uint; +typedef long adlx_long; +typedef uint8_t adlx_uint8; +typedef enum +{ + ADLX_OK = 0, /**< @ENG_START_DOX This result indicates success. @ENG_END_DOX */ + ADLX_ALREADY_ENABLED, /**< @ENG_START_DOX This result indicates that the asked action is already enabled. @ENG_END_DOX */ + ADLX_ALREADY_INITIALIZED, /**< @ENG_START_DOX This result indicates that ADLX has a unspecified type of initialization. @ENG_END_DOX */ + ADLX_FAIL, /**< @ENG_START_DOX This result indicates an unspecified failure. @ENG_END_DOX */ + ADLX_INVALID_ARGS, /**< @ENG_START_DOX This result indicates that the arguments are invalid. @ENG_END_DOX */ + ADLX_BAD_VER, /**< @ENG_START_DOX This result indicates that the asked version is incompatible with the current version. @ENG_END_DOX */ + ADLX_UNKNOWN_INTERFACE, /**< @ENG_START_DOX This result indicates that an unknown interface was asked. @ENG_END_DOX */ + ADLX_TERMINATED, /**< @ENG_START_DOX This result indicates that the calls were made in an interface after ADLX was terminated. @ENG_END_DOX */ + ADLX_ADL_INIT_ERROR, /**< @ENG_START_DOX This result indicates that the ADL initialization failed. @ENG_END_DOX */ + ADLX_NOT_FOUND, /**< @ENG_START_DOX This result indicates that the item is not found. @ENG_END_DOX */ + ADLX_INVALID_OBJECT, /**< @ENG_START_DOX This result indicates that the method was called into an invalid object. @ENG_END_DOX */ + ADLX_ORPHAN_OBJECTS, /**< @ENG_START_DOX This result indicates that ADLX was terminated with outstanding ADLX objects. Any interface obtained from ADLX points to invalid memory and calls in their methods will result in unexpected behavior. @ENG_END_DOX */ + ADLX_NOT_SUPPORTED, /**< @ENG_START_DOX This result indicates that the asked feature is not supported. @ENG_END_DOX */ + ADLX_PENDING_OPERATION, /**< @ENG_START_DOX This result indicates a failure due to an operation currently in progress. @ENG_END_DOX */ + ADLX_GPU_INACTIVE /**< @ENG_START_DOX This result indicates that the GPU is inactive. @ENG_END_DOX */ +} ADLX_RESULT; +#define ADLX_SUCCEEDED(x) (ADLX_OK == (x) || ADLX_ALREADY_ENABLED == (x) || ADLX_ALREADY_INITIALIZED == (x)) +#define ADLX_FAILED(x) (ADLX_OK != (x) && ADLX_ALREADY_ENABLED != (x) && ADLX_ALREADY_INITIALIZED != (x)) +#define ADLX_VER_MAJOR 1 +#define ADLX_VER_MINOR 0 +#define ADLX_VER_RELEASE 5 +#define ADLX_VER_BUILD_NUM 30 +#define ADLX_MAKE_FULL_VER(VERSION_MAJOR, VERSION_MINOR, VERSION_RELEASE, VERSION_BUILD_NUM) ( ((adlx_uint64)(VERSION_MAJOR) << 48ull) | ((adlx_uint64)(VERSION_MINOR) << 32ull) | ((adlx_uint64)(VERSION_RELEASE) << 16ull) | (adlx_uint64)(VERSION_BUILD_NUM)) +#define ADLX_FULL_VERSION ADLX_MAKE_FULL_VER(ADLX_VER_MAJOR, ADLX_VER_MINOR, ADLX_VER_RELEASE, ADLX_VER_BUILD_NUM) +#define ADLX_CORE_LINK __declspec(dllexport) +#define ADLX_STD_CALL __stdcall +#define ADLX_CDECL_CALL __cdecl +#define ADLX_FAST_CALL __fastcall +#define ADLX_INLINE __inline +#define ADLX_FORCEINLINE __forceinline +#define ADLX_NO_VTABLE __declspec(novtable) + +#if defined(__cplusplus) +typedef bool adlx_bool; +#else +typedef adlx_uint8 adlx_bool; +#define true 1 +#define false 0 +#endif + +typedef struct IADLXSystem IADLXSystem; +typedef struct IADLXGPUList IADLXGPUList; +typedef struct IADLXGPU IADLXGPU; +typedef struct IADLXInterface IADLXInterface; +typedef struct IADLXPerformanceMonitoringServices IADLXPerformanceMonitoringServices; +typedef struct IADLXGPUMetrics IADLXGPUMetrics; +typedef struct IADLXGPUMetricsSupport IADLXGPUMetricsSupport; + +typedef struct IADLXSystemVtbl +{ + // IADLXSystem interface + ADLX_RESULT (ADLX_STD_CALL *GetHybridGraphicsType)(/* IADLXSystem* pThis, ADLX_HG_TYPE* hgType */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUs)(IADLXSystem* pThis, IADLXGPUList** ppGPUs); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXSystem* pThis, const wchar_t* interfaceId, void** ppInterface */); + ADLX_RESULT (ADLX_STD_CALL *GetDisplaysServices)(/* IADLXSystem* pThis, IADLXDisplayServices** ppDispServices */); + ADLX_RESULT (ADLX_STD_CALL *GetDesktopsServices)(/* IADLXSystem* pThis, IADLXDesktopServices** ppDeskServices */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUsChangedHandling)(/* IADLXSystem* pThis, IADLXGPUsChangedHandling** ppGPUsChangedHandling */); + ADLX_RESULT (ADLX_STD_CALL *EnableLog)(/* IADLXSystem* pThis, ADLX_LOG_DESTINATION mode, ADLX_LOG_SEVERITY severity, IADLXLog* pLogger, const wchar_t* fileName */); + ADLX_RESULT (ADLX_STD_CALL *Get3DSettingsServices)(/* IADLXSystem* pThis, IADLX3DSettingsServices** pp3DSettingsServices */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUTuningServices)(/* IADLXSystem* pThis, IADLXGPUTuningServices** ppGPUTuningServices */); + ADLX_RESULT (ADLX_STD_CALL *GetPerformanceMonitoringServices)(IADLXSystem* pThis, IADLXPerformanceMonitoringServices** ppPerformanceMonitoringServices); // Used + ADLX_RESULT (ADLX_STD_CALL *TotalSystemRAM)(/* IADLXSystem* pThis, adlx_uint* ramMB */); + ADLX_RESULT (ADLX_STD_CALL *GetI2C)(/* IADLXSystem* pThis, IADLXGPU* pGPU, IADLXI2C** ppI2C */); +} IADLXSystemVtbl; +struct IADLXSystem { const IADLXSystemVtbl *pVtbl; }; + +typedef struct IADLXGPUVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPU* pThis */); + adlx_long (ADLX_STD_CALL *Release)(IADLXGPU* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPU* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXGPU + ADLX_RESULT (ADLX_STD_CALL *VendorId)(/* IADLXGPU* pThis, const char** vendorId */); + ADLX_RESULT (ADLX_STD_CALL *ASICFamilyType)(/* IADLXGPU* pThis, ADLX_ASIC_FAMILY_TYPE* asicFamilyType */); + ADLX_RESULT (ADLX_STD_CALL *Type)(/* IADLXGPU* pThis, ADLX_GPU_TYPE* gpuType */); + ADLX_RESULT (ADLX_STD_CALL *IsExternal)(/* IADLXGPU* pThis, adlx_bool* isExternal */); + ADLX_RESULT (ADLX_STD_CALL *Name)(/* IADLXGPU* pThis, const char** gpuName */); + ADLX_RESULT (ADLX_STD_CALL *DriverPath)(/* IADLXGPU* pThis, const char** driverPath */); + ADLX_RESULT (ADLX_STD_CALL *PNPString)(/* IADLXGPU* pThis, const char** pnpString */); + ADLX_RESULT (ADLX_STD_CALL *HasDesktops)(/* IADLXGPU* pThis, adlx_bool* hasDesktops */); + ADLX_RESULT (ADLX_STD_CALL *TotalVRAM)(IADLXGPU* pThis, adlx_uint* vramMB); // Used + ADLX_RESULT (ADLX_STD_CALL *VRAMType)(/* IADLXGPU* pThis, const char** type */); + ADLX_RESULT (ADLX_STD_CALL *BIOSInfo)(/* IADLXGPU* pThis, const char** partNumber, const char** version, const char** date */); + ADLX_RESULT (ADLX_STD_CALL *DeviceId)(/* IADLXGPU* pThis, const char** deviceId */); + ADLX_RESULT (ADLX_STD_CALL *RevisionId)(/* IADLXGPU* pThis, const char** revisionId */); + ADLX_RESULT (ADLX_STD_CALL *SubSystemId)(/* IADLXGPU* pThis, const char** subSystemId */); + ADLX_RESULT (ADLX_STD_CALL *SubSystemVendorId)(/* IADLXGPU* pThis, const char** subSystemVendorId */); + ADLX_RESULT (ADLX_STD_CALL *UniqueId)(IADLXGPU* pThis, adlx_int* uniqueId); // Used +} IADLXGPUVtbl; +struct IADLXGPU { const IADLXGPUVtbl *pVtbl; }; + +typedef struct IADLXGPUListVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPUList* pThis */); + adlx_long (ADLX_STD_CALL *Release)(IADLXGPUList* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPUList* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXList + adlx_uint (ADLX_STD_CALL *Size)(/* IADLXGPUList* pThis */); + adlx_uint8 (ADLX_STD_CALL *Empty)(/* IADLXGPUList* pThis */); + adlx_uint (ADLX_STD_CALL *Begin)(IADLXGPUList* pThis); // Used + adlx_uint (ADLX_STD_CALL *End)(IADLXGPUList* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *At)(/* IADLXGPUList* pThis, const adlx_uint location, IADLXInterface** ppItem */); + ADLX_RESULT (ADLX_STD_CALL *Clear)(/* IADLXGPUList* pThis */); + ADLX_RESULT (ADLX_STD_CALL *Remove_Back)(/* IADLXGPUList* pThis */); + ADLX_RESULT (ADLX_STD_CALL *Add_Back)(/* IADLXGPUList* pThis, IADLXInterface* pItem */); + + //IADLXGPUList + ADLX_RESULT (ADLX_STD_CALL *At_GPUList)(IADLXGPUList* pThis, const adlx_uint location, IADLXGPU** ppItem); // Used + ADLX_RESULT (ADLX_STD_CALL *Add_Back_GPUList)(/* IADLXGPUList* pThis, IADLXGPU* pItem */); + +} IADLXGPUListVtbl; +struct IADLXGPUList { const IADLXGPUListVtbl *pVtbl; }; + +typedef struct IADLXPerformanceMonitoringServicesVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXPerformanceMonitoringServices* pThis */); + adlx_long (ADLX_STD_CALL *Release)(IADLXPerformanceMonitoringServices* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXPerformanceMonitoringServices* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXPerformanceMonitoringServices + ADLX_RESULT (ADLX_STD_CALL *GetSamplingIntervalRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); + ADLX_RESULT (ADLX_STD_CALL *SetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int intervalMs */); + ADLX_RESULT (ADLX_STD_CALL *GetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* intervalMs */); + ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySizeRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); + ADLX_RESULT (ADLX_STD_CALL *SetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int sizeSec */); + ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); + ADLX_RESULT (ADLX_STD_CALL *ClearPerformanceMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); + ADLX_RESULT (ADLX_STD_CALL *StartPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); + ADLX_RESULT (ADLX_STD_CALL *StopPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); + ADLX_RESULT (ADLX_STD_CALL *GetAllMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXAllMetricsList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, adlx_int startMs, adlx_int stopMs, IADLXGPUMetricsList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetSystemMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXSystemMetricsList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetFPSHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXFPSList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentAllMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXAllMetrics** ppMetrics */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetrics** ppMetrics); // Used + ADLX_RESULT (ADLX_STD_CALL *GetCurrentSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetrics** ppMetrics */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentFPS)(/* IADLXPerformanceMonitoringServices* pThis, IADLXFPS** ppMetrics */); + ADLX_RESULT (ADLX_STD_CALL *GetSupportedGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetricsSupport** ppMetricsSupported); // Used + ADLX_RESULT (ADLX_STD_CALL *GetSupportedSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetricsSupport** ppMetricsSupported */); +}IADLXPerformanceMonitoringServicesVtbl; +struct IADLXPerformanceMonitoringServices { const IADLXPerformanceMonitoringServicesVtbl *pVtbl; }; + +typedef struct IADLXGPUMetricsSupportVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetricsSupport* pThis */); + adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetricsSupport* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetricsSupport* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXGPUMetricsSupport + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUUsage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAMClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUHotspotTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTotalBoardPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUFanSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAM)(IADLXGPUMetricsSupport* pThis, adlx_bool* supported); // Used + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVoltage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + + ADLX_RESULT (ADLX_STD_CALL* GetGPUUsageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUHotspotTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUFanSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUVoltageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUTotalBoardPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); +} IADLXGPUMetricsSupportVtbl; +struct IADLXGPUMetricsSupport { const IADLXGPUMetricsSupportVtbl *pVtbl; }; + +typedef struct IADLXGPUMetricsVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetrics* pThis */); + adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetrics* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetrics* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXGPUMetrics + ADLX_RESULT (ADLX_STD_CALL* TimeStamp)(/* IADLXGPUMetrics* pThis, adlx_int64* ms */); + ADLX_RESULT (ADLX_STD_CALL* GPUUsage)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUVRAMClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUHotspotTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUTotalBoardPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUFanSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUVRAM)(IADLXGPUMetrics* pThis, adlx_int* data); // Used + ADLX_RESULT (ADLX_STD_CALL* GPUVoltage)(/* IADLXGPUMetrics* pThis, adlx_int* data */); +} IADLXGPUMetricsVtbl; +struct IADLXGPUMetrics { const IADLXGPUMetricsVtbl *pVtbl; }; + +struct { + void *handle; + ADLX_RESULT (*ADLXInitialize)(adlx_uint64 version, IADLXSystem** ppSystem); + ADLX_RESULT (*ADLXInitializeWithIncompatibleDriver)(adlx_uint64 version, IADLXSystem** ppSystem); + ADLX_RESULT (*ADLXQueryVersion)(const char** version); + ADLX_RESULT (*ADLXTerminate)(); + IADLXSystem *sys; +} adlx { NULL, NULL, NULL, NULL, NULL, NULL }; +static std::mutex ggml_adlx_lock; + +extern "C" { + +int ggml_hip_mgmt_init() { + std::lock_guard lock(ggml_adlx_lock); + if (adlx.handle != NULL) { + // Already initialized + return 0; + } + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + fs::path libPath = fs::path("\\Windows") / fs::path("System32") / fs::path("amdadlx64.dll"); + + adlx.handle = (void*)LoadLibraryW(libPath.wstring().c_str()); + if (adlx.handle == NULL) { + return ADLX_NOT_FOUND; + } + + adlx.ADLXInitialize = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitialize"); + adlx.ADLXInitializeWithIncompatibleDriver = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitializeWithIncompatibleDriver"); + adlx.ADLXTerminate = (ADLX_RESULT (*)()) GetProcAddress((HMODULE)(adlx.handle), "ADLXTerminate"); + adlx.ADLXQueryVersion = (ADLX_RESULT (*)(const char **version)) GetProcAddress((HMODULE)(adlx.handle), "ADLXQueryVersion"); + if (adlx.ADLXInitialize == NULL || adlx.ADLXInitializeWithIncompatibleDriver == NULL || adlx.ADLXTerminate == NULL) { + GGML_LOG_INFO("%s unable to locate required symbols in amdadlx64.dll, falling back to hip free memory reporting", __func__); + FreeLibrary((HMODULE)(adlx.handle)); + adlx.handle = NULL; + return ADLX_NOT_FOUND; + } + + SetErrorMode(old_mode); + + // Aid in troubleshooting... + if (adlx.ADLXQueryVersion != NULL) { + const char *version = NULL; + ADLX_RESULT status = adlx.ADLXQueryVersion(&version); + if (ADLX_SUCCEEDED(status)) { + GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version); + } + } + + ADLX_RESULT status = adlx.ADLXInitialize(ADLX_FULL_VERSION, &adlx.sys); + if (ADLX_FAILED(status)) { + // GGML_LOG_DEBUG("%s failed to initialize ADLX error=%d - attempting with incompatible driver...\n", __func__, status); + // Try with the incompatible driver + status = adlx.ADLXInitializeWithIncompatibleDriver(ADLX_FULL_VERSION, &adlx.sys); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s failed to initialize ADLX error=%d\n", __func__, status); + FreeLibrary((HMODULE)(adlx.handle)); + adlx.handle = NULL; + adlx.sys = NULL; + return status; + } + // GGML_LOG_DEBUG("%s initialized ADLX with incpomatible driver\n", __func__); + } + return ADLX_OK; +} + +void ggml_hip_mgmt_release() { + std::lock_guard lock(ggml_adlx_lock); + if (adlx.handle == NULL) { + // Already free + return; + } + ADLX_RESULT status = adlx.ADLXTerminate(); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s failed to terminate Adlx %d\n", __func__, status); + // Unload anyway... + } + FreeLibrary((HMODULE)(adlx.handle)); + adlx.handle = NULL; +} + +#define adlx_gdm_cleanup \ + if (gpuMetricsSupport != NULL) gpuMetricsSupport->pVtbl->Release(gpuMetricsSupport); \ + if (gpuMetrics != NULL) gpuMetrics->pVtbl->Release(gpuMetrics); \ + if (perfMonitoringServices != NULL) perfMonitoringServices->pVtbl->Release(perfMonitoringServices); \ + if (gpus != NULL) gpus->pVtbl->Release(gpus); \ + if (gpu != NULL) gpu->pVtbl->Release(gpu) + +int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { + std::lock_guard lock(ggml_adlx_lock); + if (adlx.handle == NULL) { + GGML_LOG_INFO("%s ADLX was not initialized\n", __func__); + return ADLX_ADL_INIT_ERROR; + } + IADLXGPUMetricsSupport *gpuMetricsSupport = NULL; + IADLXPerformanceMonitoringServices *perfMonitoringServices = NULL; + IADLXGPUList* gpus = NULL; + IADLXGPU* gpu = NULL; + IADLXGPUMetrics *gpuMetrics = NULL; + ADLX_RESULT status; + // The "UniqueID" exposed in ADLX is the PCI Bus and Device IDs + adlx_int target = (pci_bus_id << 8) | (pci_device_id & 0xff); + + status = adlx.sys->pVtbl->GetPerformanceMonitoringServices(adlx.sys, &perfMonitoringServices); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetPerformanceMonitoringServices failed %d\n", __func__, status); + return status; + } + + status = adlx.sys->pVtbl->GetGPUs(adlx.sys, &gpus); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetGPUs failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + // Get GPU list + for (adlx_uint crt = gpus->pVtbl->Begin(gpus); crt != gpus->pVtbl->End(gpus); ++crt) + { + status = gpus->pVtbl->At_GPUList(gpus, crt, &gpu); + if (ADLX_FAILED(status)) + { + GGML_LOG_INFO("%s %d] At_GPUList failed %d\n", __func__, crt, status); + continue; + } + adlx_int id; + status = gpu->pVtbl->UniqueId(gpu, &id); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s %d] UniqueId lookup failed %d\n", __func__, crt, status); + gpu->pVtbl->Release(gpu); + gpu = NULL; + continue; + } + if (id != target) { + GGML_LOG_DEBUG("%s %d] GPU UniqueId: %x does not match target %02x %02x\n", __func__, crt, id, pci_bus_id, pci_device_id); + gpu->pVtbl->Release(gpu); + gpu = NULL; + continue; + } + // Any failures at this point should cause a fall-back to other APIs + status = perfMonitoringServices->pVtbl->GetSupportedGPUMetrics(perfMonitoringServices, gpu, &gpuMetricsSupport); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetSupportedGPUMetrics failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + status = perfMonitoringServices->pVtbl->GetCurrentGPUMetrics(perfMonitoringServices, gpu, &gpuMetrics); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetCurrentGPUMetrics failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + adlx_bool supported = false; + status = gpuMetricsSupport->pVtbl->IsSupportedGPUVRAM(gpuMetricsSupport, &supported); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s IsSupportedGPUVRAM failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + adlx_uint totalVRAM = 0; + status = gpu->pVtbl->TotalVRAM(gpu, &totalVRAM); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s TotalVRAM failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + adlx_int usedVRAM = 0; + status = gpuMetrics->pVtbl->GPUVRAM(gpuMetrics, &usedVRAM); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GPUVRAM failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + *total = size_t(totalVRAM) * 1024 * 1024; + *free = size_t(totalVRAM-usedVRAM) * 1024 * 1024; + + adlx_gdm_cleanup; + return ADLX_OK; + } + adlx_gdm_cleanup; + return ADLX_NOT_FOUND; +} + +} // extern "C" + +#else // #ifdef _WIN32 + +extern "C" { + +// TODO Linux implementation of accurate VRAM reporting +int ggml_hip_mgmt_init() { + return -1; +} +void ggml_hip_mgmt_release() {} +int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { + return -1; +} + +} // extern "C" + +#endif // #ifdef _WIN32 \ No newline at end of file diff --git a/ml/backend/ggml/ggml/src/mem_nvml.cpp b/ml/backend/ggml/ggml/src/mem_nvml.cpp new file mode 100644 index 000000000..aa05e9dc1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/mem_nvml.cpp @@ -0,0 +1,172 @@ +// NVIDIA Management Library (NVML) +// +// https://developer.nvidia.com/management-library-nvml +// +// This library provides accurate VRAM reporting for NVIDIA GPUs, particularly +// on Windows, where the cuda library provides inaccurate VRAM usage metrics. The +// runtime DLL is installed with every driver on Windows, and most Linux +// systems, and the headers are included in the standard CUDA SDK install. As +// such, we can include the header here to simplify the code. + + +#include "ggml-impl.h" +#include +#include + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +#else +# include +# include +#endif + +namespace fs = std::filesystem; + +// Minimal definitions to avoid including the nvml.h header +typedef enum nvmlReturn_enum +{ + // cppcheck-suppress * + NVML_SUCCESS = 0, //!< The operation was successful + NVML_ERROR_UNINITIALIZED = 1, //!< NVML was not first initialized with nvmlInit() + NVML_ERROR_INVALID_ARGUMENT = 2, //!< A supplied argument is invalid + NVML_ERROR_NOT_SUPPORTED = 3, //!< The requested operation is not available on target device + NVML_ERROR_NO_PERMISSION = 4, //!< The current user does not have permission for operation + NVML_ERROR_ALREADY_INITIALIZED = 5, //!< Deprecated: Multiple initializations are now allowed through ref counting + NVML_ERROR_NOT_FOUND = 6, //!< A query to find an object was unsuccessful + NVML_ERROR_INSUFFICIENT_SIZE = 7, //!< An input argument is not large enough + NVML_ERROR_INSUFFICIENT_POWER = 8, //!< A device's external power cables are not properly attached + NVML_ERROR_DRIVER_NOT_LOADED = 9, //!< NVIDIA driver is not loaded + NVML_ERROR_TIMEOUT = 10, //!< User provided timeout passed + NVML_ERROR_IRQ_ISSUE = 11, //!< NVIDIA Kernel detected an interrupt issue with a GPU + NVML_ERROR_LIBRARY_NOT_FOUND = 12, //!< NVML Shared Library couldn't be found or loaded + NVML_ERROR_FUNCTION_NOT_FOUND = 13, //!< Local version of NVML doesn't implement this function + NVML_ERROR_CORRUPTED_INFOROM = 14, //!< infoROM is corrupted + NVML_ERROR_GPU_IS_LOST = 15, //!< The GPU has fallen off the bus or has otherwise become inaccessible + NVML_ERROR_RESET_REQUIRED = 16, //!< The GPU requires a reset before it can be used again + NVML_ERROR_OPERATING_SYSTEM = 17, //!< The GPU control device has been blocked by the operating system/cgroups + NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18, //!< RM detects a driver/library version mismatch + NVML_ERROR_IN_USE = 19, //!< An operation cannot be performed because the GPU is currently in use + NVML_ERROR_MEMORY = 20, //!< Insufficient memory + NVML_ERROR_NO_DATA = 21, //!< No data + NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, becasue ECC is enabled + NVML_ERROR_INSUFFICIENT_RESOURCES = 23, //!< Ran out of critical resources, other than memory + NVML_ERROR_FREQ_NOT_SUPPORTED = 24, //!< Ran out of critical resources, other than memory + NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25, //!< The provided version is invalid/unsupported + NVML_ERROR_DEPRECATED = 26, //!< The requested functionality has been deprecated + NVML_ERROR_NOT_READY = 27, //!< The system is not ready for the request + NVML_ERROR_GPU_NOT_FOUND = 28, //!< No GPUs were found + NVML_ERROR_INVALID_STATE = 29, //!< Resource not in correct state to perform requested operation + NVML_ERROR_UNKNOWN = 999 //!< An internal driver error occurred +} nvmlReturn_t; +typedef struct nvmlDevice_st* nvmlDevice_t; +typedef struct nvmlMemory_st +{ + unsigned long long total; //!< Total physical device memory (in bytes) + unsigned long long free; //!< Unallocated device memory (in bytes) + unsigned long long used; //!< Sum of Reserved and Allocated device memory (in bytes). + //!< Note that the driver/GPU always sets aside a small amount of memory for bookkeeping +} nvmlMemory_t; +// end nvml.h definitions + +struct { + void *handle; + nvmlReturn_t (*nvmlInit_v2)(void); + nvmlReturn_t (*nvmlShutdown)(void); + nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); + nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); +} nvml { NULL, NULL, NULL, NULL, NULL }; +static std::mutex ggml_nvml_lock; + +extern "C" { + +int ggml_nvml_init() { + std::lock_guard lock(ggml_nvml_lock); + if (nvml.handle != NULL) { + // Already initialized + return 0; + } +#ifdef _WIN32 + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + fs::path libPath[2]; + const char * programDir = std::getenv("ProgramW6432"); + if (programDir == NULL) { + libPath[0] = fs::path("Program Files") / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); + } else { + libPath[0] = fs::path(programDir) / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); + } + libPath[1] = fs::path("\\Windows") / fs::path("System32") / fs::path("NVML.dll"); + + for (int i = 0; i < 2; i++) { + nvml.handle = (void*)LoadLibraryW(libPath[i].wstring().c_str()); + if (nvml.handle != NULL) { + break; + } + } + if (nvml.handle == NULL) { + return NVML_ERROR_NOT_FOUND; + } + + nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlInit_v2"); + nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown"); + nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID"); + nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo"); + if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) { + GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__); + FreeLibrary((HMODULE)(nvml.handle)); + nvml.handle = NULL; + return NVML_ERROR_NOT_FOUND; + } + + SetErrorMode(old_mode); + +#else + // Not currently wired up on Linux + return NVML_ERROR_NOT_SUPPORTED; +#endif + int status = nvml.nvmlInit_v2(); + return NVML_SUCCESS; +} + +void ggml_nvml_release() { + std::lock_guard lock(ggml_nvml_lock); + if (nvml.handle == NULL) { + // Already free + return; + } + nvmlReturn_enum status = nvml.nvmlShutdown(); + if (status != NVML_SUCCESS) { + GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status); + } +#ifdef _WIN32 + FreeLibrary((HMODULE)(nvml.handle)); + nvml.handle = NULL; +#else + // Not currently wired up on Linux +#endif +} + +int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) { + std::lock_guard lock(ggml_nvml_lock); + if (nvml.handle == NULL) { + return NVML_ERROR_UNINITIALIZED; + } + nvmlDevice_t device; + auto status = nvml.nvmlDeviceGetHandleByUUID(uuid, &device); + if (status != NVML_SUCCESS) { + return status; + } + nvmlMemory_t memInfo = {0}; + status = nvml.nvmlDeviceGetMemoryInfo(device, &memInfo); + if (status == NVML_SUCCESS) { + *free = memInfo.free; + *total = memInfo.total; + } + return status; +} + +} \ No newline at end of file diff --git a/ml/device.go b/ml/device.go new file mode 100644 index 000000000..6569d87bb --- /dev/null +++ b/ml/device.go @@ -0,0 +1,338 @@ +package ml + +import ( + "context" + "encoding/binary" + "fmt" + "hash/maphash" + "log/slog" + "slices" + "sort" + "strconv" + "strings" + + "github.com/ollama/ollama/format" +) + +// GPULayers is a set of layers to be allocated on a single GPU +type GPULayers struct { + DeviceID + + // Layers is a set of layer indicies to load + Layers []int +} + +func (g GPULayers) String() string { + if len(g.Layers) == 0 { + return "" + } + + slices.Sort(g.Layers) + + contiguous := true + base := g.Layers[0] + for i := range g.Layers { + if g.Layers[i] != base+i { + contiguous = false + break + } + } + + if contiguous { + return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1]) + } else { + return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers) + } +} + +// GPULayersList is a set of layer allocations across multiple GPUs +type GPULayersList []GPULayers + +func (l GPULayersList) String() string { + if l.Sum() > 0 { + return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l)) + } else { + return fmt.Sprintf("%v", []GPULayers(l)) + } +} + +// Sum is the total number of layers assigned across all GPUs +func (l GPULayersList) Sum() int { + var sum int + + for _, g := range l { + sum += len(g.Layers) + } + + return sum +} + +var h maphash.Hash + +// Hash is an identifier of this layer assignment +func (l GPULayersList) Hash() uint64 { + h.Reset() + for _, g := range l { + if len(g.Layers) > 0 { + h.WriteString(g.ID + g.Library) + for _, l := range g.Layers { + binary.Write(&h, binary.NativeEndian, int64(l)) + } + } + } + + return h.Sum64() +} + +// ErrNoMem is returned when panicing due to insufficient memory. It includes +// the attempted memory allocation. +type ErrNoMem struct { + BackendMemory +} + +func (e ErrNoMem) Error() string { + return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory) +} + +// Minimal unique device identification +type DeviceID struct { + // ID is an identifier for the device for matching with system + // management libraries. The ID is only unique for other devices + // using the same Library. + // This ID represents a "post filtered" view of the enumerated devices + // if the ID is numeric + ID string `json:"id"` + + // Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.) + Library string `json:"backend,omitempty"` +} + +// DeviceMemory provides a breakdown of the memory needed +// per device, such as a CPU or GPU. +type DeviceMemory struct { + DeviceID + + // Name is the name of the device as labeled by the backend. It + // may not be persistent across instances of the runner. + Name string + + // Weights is the per-layer memory needed for the model weights. + Weights []uint64 + + // Cache is the per-layer memory needed for the KV cache. + Cache []uint64 + + // Graph is the size of the compute graph. It is not per-layer. + Graph uint64 +} + +func sumMemory(mem []uint64) uint64 { + var sum uint64 + + for _, m := range mem { + sum += m + } + + return sum +} + +// Size returns the total size of the memory required by this device +func (m DeviceMemory) Size() uint64 { + return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph +} + +func memoryPresent(mem []uint64) bool { + return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 }) +} + +func (m DeviceMemory) LogValue() slog.Value { + var attrs []slog.Attr + if memoryPresent(m.Weights) { + attrs = append(attrs, slog.Any("Weights", m.Weights)) + } + + if memoryPresent(m.Cache) { + attrs = append(attrs, slog.Any("Cache", m.Cache)) + } + + if m.Graph != 0 { + attrs = append(attrs, slog.Any("Graph", m.Graph)) + } + + if len(attrs) > 0 && m.ID != "" { + attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...) + } + + return slog.GroupValue(attrs...) +} + +// BackendMemory provides the amount of memory required to load the model +// per device based on the BackendParams. In some cases, not all required +// allocations will be known at this point. However, the size of the most recent +// allocation is guaranteed to be provided so that if it failed, the caller can +// accommodate that to make forward progress. +type BackendMemory struct { + // InputWeights are always located on the CPU and cannot be moved + InputWeights uint64 + + // CPU model components are located in system memory. This does not + // include unified memory allocated through the GPU. + CPU DeviceMemory + + // GPU model components are located on one or more GPUs. + GPUs []DeviceMemory +} + +func (m BackendMemory) LogValue() slog.Value { + var attrs []slog.Attr + if m.InputWeights != 0 { + attrs = append(attrs, slog.Any("InputWeights", m.InputWeights)) + } + + attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU)) + for _, g := range m.GPUs { + attrs = append(attrs, slog.Any(g.Name, g)) + } + + return slog.GroupValue(attrs...) +} + +// Log prints a high level summary of the memory +func (m BackendMemory) Log(level slog.Level) { + var total uint64 + + for _, gpu := range m.GPUs { + if sum := sumMemory(gpu.Weights); sum > 0 { + slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 { + slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + for _, gpu := range m.GPUs { + if sum := sumMemory(gpu.Cache); sum > 0 { + slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := sumMemory(m.CPU.Cache); sum > 0 { + slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + for _, gpu := range m.GPUs { + if sum := gpu.Graph; sum > 0 { + slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := m.CPU.Graph; sum > 0 { + slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + if total > 0 { + slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total)) + } +} + +type DeviceInfo struct { + DeviceID + + // Name is the name of the device as labeled by the backend. It + // may not be persistent across instances of the runner. + Name string `json:"name"` + + // Description is the longer user-friendly identification of the device + Description string `json:"description"` + + // FilterID is populated with the unfiltered device ID if a numeric ID is used + // so the device can be included. + FilteredID string `json:"filtered_id,omitempty"` + + // Integrated is set true for integrated GPUs, false for Discrete GPUs + Integrated bool `json:"integration,omitempty"` + + // PCIID is the bus, device and domain ID of the device for deduplication + // when discovered by multiple backends + PCIID string `json:"pci_id,omitempty"` + + // TotalMemory is the total amount of memory the device can use for loading models + TotalMemory uint64 `json:"total_memory"` + + // FreeMemory is the amount of memory currently available on the device for loading models + FreeMemory uint64 `json:"free_memory,omitempty"` + + // ComputeMajor is the major version of capabilities of the device + // if unsupported by the backend, -1 will be returned + ComputeMajor int + + // ComputeMinor is the minor version of capabilities of the device + // if unsupported by the backend, -1 will be returned + ComputeMinor int + + // Driver Information + DriverMajor int `json:"driver_major,omitempty"` + DriverMinor int `json:"driver_minor,omitempty"` + + // Where backends were loaded from + LibraryPath []string +} + +func (d DeviceInfo) Compute() string { + // AMD gfx is encoded into the major minor in hex form + if strings.EqualFold(d.Library, "ROCm") { + return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor) + } + return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor) +} + +func (d DeviceInfo) Driver() string { + return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor) +} + +type DeviceComparison int + +const ( + UniqueDevice DeviceComparison = iota + SameBackendDevice // The device is the same, and the library/backend is the same + DuplicateDevice // The same physical device but different library/backend (overlapping device) +) + +func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison { + if a.PCIID != b.PCIID { + return UniqueDevice + } + if a.Library == b.Library { + return SameBackendDevice + } + return DuplicateDevice +} + +// For a SameBackendDevice, return true if b is better than a +// e.g. newer GPU library version +func (a DeviceInfo) IsBetter(b DeviceInfo) bool { + aLib := a.LibraryPath[len(a.LibraryPath)-1] + bLib := b.LibraryPath[len(b.LibraryPath)-1] + if aLib == bLib { + return false + } + aLibSplit := strings.SplitN(aLib, "_", 2) + bLibSplit := strings.SplitN(bLib, "_", 2) + if len(aLibSplit) < 2 || len(bLibSplit) < 2 { + return false + } + if aLibSplit[0] != bLibSplit[0] { + slog.Debug("unexpected libraries", "a", aLib, "b", bLib) + return false + } + if aLibSplit[1] == bLibSplit[1] { + return false + } + cmp := []string{aLibSplit[1], bLibSplit[1]} + sort.Sort(sort.Reverse(sort.StringSlice(cmp))) + return cmp[0] == bLibSplit[1] +} diff --git a/ml/nn/attention.go b/ml/nn/attention.go index 21b4a28ae..94dbde0b0 100644 --- a/ml/nn/attention.go +++ b/ml/nn/attention.go @@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache } func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { + ctx.Forward(query) if key != nil && value != nil { if query.Dim(0) != key.Dim(0) { panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) @@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) } + ctx.Forward(key, value) if cache != nil { cache.Put(ctx, key, value) } diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go new file mode 100644 index 000000000..63b63b3af --- /dev/null +++ b/ml/nn/pooling/pooling.go @@ -0,0 +1,42 @@ +package pooling + +import ( + "github.com/ollama/ollama/ml" +) + +type Type uint32 + +const ( + TypeNone Type = iota + TypeMean + TypeCLS + TypeLast +) + +func (t Type) String() string { + switch t { + case TypeMean: + return "Mean" + case TypeCLS: + return "CLS" + case TypeLast: + return "Last" + default: + return "Unknown" + } +} + +func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { + switch t { + case TypeMean: + hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) + return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + case TypeCLS: + return hiddenStates.View(ctx, 0, hiddenStates.Dim(0)) + case TypeLast: + hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0)) + return hiddenStates + default: + panic("unknown pooling type") + } +} diff --git a/ml/nn/pooling/pooling_test.go b/ml/nn/pooling/pooling_test.go new file mode 100644 index 000000000..e27727462 --- /dev/null +++ b/ml/nn/pooling/pooling_test.go @@ -0,0 +1,64 @@ +package pooling_test + +import ( + "bytes" + "os" + "testing" + + "github.com/google/go-cmp/cmp" + fsggml "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/backend/ggml" + "github.com/ollama/ollama/ml/nn/pooling" +) + +func setup(tb testing.TB, n int) ml.Backend { + tb.Helper() + + f, err := os.CreateTemp(tb.TempDir(), "*.bin") + if err != nil { + tb.Fatal(err) + } + defer f.Close() + + if err := fsggml.WriteGGUF(f, fsggml.KV{ + "general.architecture": "test", + "test.block_count": uint32(1), + }, []*fsggml.Tensor{ + {Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))}, + }); err != nil { + tb.Fatal(err) + } + + b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true}) + if err != nil { + tb.Fatal(err) + } + + return b +} + +func TestForward(t *testing.T) { + cases := map[pooling.Type][]float32{ + pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11}, + pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7}, + pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15}, + } + for typ, want := range cases { + t.Run(typ.String(), func(t *testing.T) { + b := setup(t, 99) + defer b.Close() + + ctx := b.NewContext() + defer ctx.Close() + + tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2) + tt = typ.Forward(ctx, tt) + + ctx.Forward(tt).Compute(tt) + if diff := cmp.Diff(want, tt.Floats()); diff != "" { + t.Error(diff) + } + }) + } +} diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go index e4083dfce..3d51f70e8 100644 --- a/model/bytepairencoding.go +++ b/model/bytepairencoding.go @@ -2,10 +2,10 @@ package model import ( "cmp" - "context" "fmt" "iter" "log/slog" + "slices" "strings" "github.com/dlclark/regexp2" @@ -14,16 +14,28 @@ import ( ) type BytePairEncoding struct { - pre *regexp2.Regexp - vocab *Vocabulary + vocab *Vocabulary + regexps []*regexp2.Regexp } var _ TextProcessor = (*BytePairEncoding)(nil) -func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { +func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding { + if len(pretokenizers) == 0 { + // set default byte-level pretokenizer if none provided, e.g. + // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44 + pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`} + } + return BytePairEncoding{ - pre: regexp2.MustCompile(pre, regexp2.None), vocab: vocab, + regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) { + for _, p := range pretokenizers { + if !yield(regexp2.MustCompile(p, regexp2.RE2)) { + return + } + } + }), } } @@ -36,13 +48,36 @@ func (bpe BytePairEncoding) Is(id int32, special Special) bool { } func (bpe *BytePairEncoding) split(s string) iter.Seq[string] { - return func(yield func(string) bool) { - for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) { - if !yield(m.String()) { - break + parts := []string{s} + for _, re := range bpe.regexps { + parts = slices.Collect(func(yield func(string) bool) { + for _, part := range parts { + r := []rune(part) + var offset int + for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) { + if offset-m.Index != 0 { + if !yield(string(r[:m.Index])) { + return + } + } + + if !yield(m.String()) { + return + } + + offset = m.Index + m.Length + } + + if offset < len(r) { + if !yield(string(r[offset:])) { + return + } + } } - } + }) } + + return slices.Values(parts) } // fragment is a string fragment and their corresponding token IDs @@ -202,12 +237,11 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { } } - slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids) - if addSpecial && len(ids) > 0 { ids = bpe.vocab.addSpecials(ids) } + logutil.Trace("encoded", "string", s, "ids", ids) return ids, nil } @@ -243,6 +277,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { } } - slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids}) + logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids}) return sb.String(), nil } diff --git a/model/bytepairencoding_test.go b/model/bytepairencoding_test.go index 71947be99..39e5ab452 100644 --- a/model/bytepairencoding_test.go +++ b/model/bytepairencoding_test.go @@ -59,12 +59,12 @@ func llama(t testing.TB) BytePairEncoding { } return NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &Vocabulary{ Values: tokens, Types: types, Merges: merges, }, + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", ) } @@ -282,3 +282,41 @@ func BenchmarkBytePairEncoding(b *testing.B) { }) } } + +func TestSplit(t *testing.T) { + cases := []struct { + name string + patterns, + want []string + }{ + { + name: "default", + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"}, + }, + { + name: "unicode", + patterns: []string{ + "\\p{N}{1,3}", + `[一-龥぀-ゟ゠-ヿ]+`, + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + }, + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"}, + }, + { + name: "individual digits", + patterns: []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }, + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tokenizer := NewBytePairEncoding(nil, tt.patterns...) + if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" { + t.Errorf("no match (-theirs +ours):\n%s", diff) + } + }) + } +} diff --git a/model/input/input.go b/model/input/input.go index bd9b53ec6..35dc41b35 100644 --- a/model/input/input.go +++ b/model/input/input.go @@ -54,10 +54,9 @@ type Batch struct { // Inputs is the input tokens, including placeholders for multimodal inputs. Inputs ml.Tensor - // Multimodal is a set of multimodal embeddings previously created by - // EncodeMultimodal, along with an index into Inputs. Unused for text-only - // models or for batches without multimodal elements. - Multimodal []MultimodalIndex + // Outputs are the set of indicies into Inputs for which output data should + // be returned. + Outputs ml.Tensor // Positions is the position for each Input, relative to its sequence. Equal // in length to Inputs. @@ -66,7 +65,8 @@ type Batch struct { // Sequences is the sequence for each Input. Equal in length to Inputs. Sequences []int - // Outputs are the set of indicies into Inputs for which output data should - // be returned. - Outputs []int32 + // Multimodal is a set of multimodal embeddings previously created by + // EncodeMultimodal, along with an index into Inputs. Unused for text-only + // models or for batches without multimodal elements. + Multimodal []MultimodalIndex } diff --git a/model/model.go b/model/model.go index d0fe26d7e..0af16da80 100644 --- a/model/model.go +++ b/model/model.go @@ -1,7 +1,6 @@ package model import ( - "context" "errors" "fmt" _ "image/jpeg" @@ -22,10 +21,15 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model/input" ) -var ErrNoVisionModel = errors.New("this model is missing data required for image input") +var ( + ErrNoVisionModel = errors.New("this model is missing data required for image input") + ErrUnsupportedModel = errors.New("model not supported") + ErrUnsupportedTokenizer = errors.New("tokenizer not supported") +) // Model implements a specific model architecture, defining the forward pass and any model-specific configuration type Model interface { @@ -64,7 +68,7 @@ type MultimodalProcessor interface { // This function is also responsible for updating MultimodalHash for any Multimodal // that is modified to ensure that there is a unique hash value that accurately // represents the contents. - PostTokenize([]input.Input) ([]input.Input, error) + PostTokenize([]*input.Input) ([]*input.Input, error) } // Base implements the common fields and methods for all models @@ -104,19 +108,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { return nil, err } - arch := b.Config().Architecture() - f, ok := models[arch] - if !ok { - return nil, fmt.Errorf("unsupported model architecture %q", arch) - } - - m, err := f(b.Config()) + m, err := modelForArch(b.Config()) if err != nil { return nil, err } base := Base{b: b, config: m.Config()} - v := reflect.ValueOf(m) v.Elem().Set(populateFields(base, v.Elem())) return m, nil @@ -128,30 +125,38 @@ func NewTextProcessor(s string) (TextProcessor, error) { return nil, err } defer r.Close() + meta, err := fsggml.Decode(r, -1) if err != nil { return nil, err } - return getTextProcessor(meta.KV()) -} -func getTextProcessor(kv fsggml.KV) (TextProcessor, error) { - arch := kv.Architecture() - f, ok := models[arch] - if !ok { - return nil, fmt.Errorf("unsupported model architecture %q", arch) - } - m, err := f(kv) + m, err := modelForArch(meta.KV()) if err != nil { return nil, err } + tp, ok := m.(TextProcessor) if !ok { - return nil, fmt.Errorf("%v is not a TextProcessor", m) + return nil, ErrUnsupportedTokenizer } return tp, nil } +func modelForArch(c fs.Config) (Model, error) { + arch := c.Architecture() + if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone { + arch = arch + "_embed" + } + + f, ok := models[arch] + if !ok { + return nil, ErrUnsupportedModel + } + + return f(c) +} + func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { t := v.Type() @@ -167,38 +172,47 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { // make a copy tagsCopy := tags if tag := t.Field(i).Tag.Get("gguf"); tag != "" { - tagsCopy = append(tagsCopy, ParseTags(tag)) + tagsCopy = append(tagsCopy, parseTag(tag)) } if tt == reflect.TypeOf((*Base)(nil)).Elem() { vv.Set(reflect.ValueOf(base)) } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { - var fn func([]Tag) [][]string - fn = func(tags []Tag) (names [][]string) { + var fn func([]Tag, string, string) [][]string + fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) { if len(tags) > 0 { - localNames := []string{tags[0].Name} - localNames = append(localNames, tags[0].Alternate...) - - for _, localName := range localNames { - fullName := []string{localName} - nested := fn(tags[1:]) - if len(nested) > 0 { - for _, rest := range nested { - names = append(names, append(fullName, rest...)) + var names []string + if tags[0].name != "" { + for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) { + names = append(names, prefix+n+suffix) + } + } + childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix) + if len(names) == 0 { + // current tag has no name, use child names only + fullNames = append(fullNames, childNames...) + } else if len(childNames) == 0 { + // current tag has names but no children, create branches for each name + for _, name := range names { + fullNames = append(fullNames, []string{name}) + } + } else { + // merge each name with each child + for _, name := range names { + for _, childName := range childNames { + fullNames = append(fullNames, append([]string{name}, childName...)) } - } else { - names = append(names, fullName) } } } - return names + return fullNames } - names := fn(tagsCopy) + names := fn(tagsCopy, "", "") for _, name := range names { if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { - slog.Log(context.TODO(), logutil.LevelTrace, "found tensor", "", tensor) + logutil.Trace("found tensor", "", tensor) vv.Set(reflect.ValueOf(tensor)) break } @@ -209,9 +223,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { for i := range vv.Len() { vvv := vv.Index(i) if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { - setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) + setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})) } else { - vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) + vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...)) } } } @@ -239,7 +253,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) { vv = vv.Elem() } - vv = vv.Elem() + vv = reflect.Indirect(vv) if v.IsNil() { vv = reflect.New(v.Type().Elem()).Elem() } @@ -250,18 +264,31 @@ func setPointer(base Base, v reflect.Value, tags []Tag) { } type Tag struct { - Name string - Alternate []string + name, + // prefix and suffix are applied to child tags + prefix, + suffix string + alternatives []string } -func ParseTags(s string) (tag Tag) { +func parseTag(s string) (tag Tag) { parts := strings.Split(s, ",") if len(parts) > 0 { - tag.Name = parts[0] + tag.name = parts[0] for _, part := range parts[1:] { - if value, ok := strings.CutPrefix(part, "alt:"); ok { - tag.Alternate = append(tag.Alternate, value) + if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" { + // elevate alternative to primary if no primary given + tag.name = value + slog.Warn("gguf tag has alt: but no primary name", "tag", s) + } else if ok { + tag.alternatives = append(tag.alternatives, value) + } + if value, ok := strings.CutPrefix(part, "pre:"); ok { + tag.prefix = value + } + if value, ok := strings.CutPrefix(part, "suf:"); ok { + tag.suffix = value } } } @@ -278,7 +305,7 @@ func canNil(t reflect.Type) bool { t.Kind() == reflect.Slice } -func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) { +func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) { if len(batch.Positions) != len(batch.Sequences) { return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences)) } @@ -287,8 +314,6 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten return nil, errors.New("batch size cannot be less than 1") } - batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs)) - cache := m.Config().Cache if cache != nil { err := cache.StartForward(ctx, batch, false) @@ -302,7 +327,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten return nil, err } - ctx.Forward(t).Compute(t) + ctx.Forward(t) return t, nil } diff --git a/model/model_test.go b/model/model_test.go index 020f9ffbd..f6d75b230 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -1,9 +1,9 @@ package model import ( + "errors" "reflect" "slices" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -12,7 +12,6 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/model/input" ) func TestParseTags(t *testing.T) { @@ -23,14 +22,14 @@ func TestParseTags(t *testing.T) { { value: "output", want: Tag{ - Name: "output", + name: "output", }, }, { value: "output,alt:token_embd", want: Tag{ - Name: "output", - Alternate: []string{ + name: "output", + alternatives: []string{ "token_embd", }, }, @@ -39,8 +38,8 @@ func TestParseTags(t *testing.T) { for _, tt := range cases { t.Run(tt.value, func(t *testing.T) { - got := ParseTags(tt.value) - if diff := cmp.Diff(tt.want, got); diff != "" { + got := parseTag(tt.value) + if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" { t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff) } }) @@ -126,6 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) { Input *nn.Embedding `gguf:"input"` Output *nn.Linear `gguf:"output,alt:input"` Nested *nested `gguf:"nested"` + Tensor ml.Tensor `gguf:"leaf,alt:tensor"` } var m fakeModel @@ -134,6 +134,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) { names: []string{ "input.weight", "nested.b.weight", + "leaf", }, }}, v.Elem())) @@ -143,44 +144,115 @@ func TestPopulateFieldsAlternateName(t *testing.T) { Nested: &nested{ Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}}, }, + Tensor: &fakeTensor{Name: "leaf"}, }, m); diff != "" { t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) } } -func TestGetTextProcessor(t *testing.T) { - tp, err := getTextProcessor(fsggml.KV{}) - if err == nil { - t.Error("expected error") - } else if !strings.Contains(err.Error(), "unsupported model architecture") { - t.Errorf("unexpected error: %v", err) - } else if tp != nil { - t.Error("expected nil tp") +func TestPopulateFieldsPrefixSuffixName(t *testing.T) { + type fakeBlock struct { + A *nn.Linear `gguf:"a"` + B *nn.Linear `gguf:",pre:b_"` + C *nn.Linear `gguf:",suf:_c"` + XY *nn.Linear `gguf:",pre:x_,suf:_y"` } - models["dummy"] = func(fs.Config) (Model, error) { - return notTextProcessorModel{}, nil + type fakeModel struct { + Blocks []fakeBlock `gguf:"blk"` } - tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"}) - if err == nil { - t.Error("expected error") - } else if !strings.Contains(err.Error(), "not a TextProcessor") { - t.Errorf("unexpected error: %v", err) - } else if tp != nil { - t.Error("expected nil tp") + + m := fakeModel{ + Blocks: make([]fakeBlock, 2), + } + v := reflect.ValueOf(&m) + v.Elem().Set(populateFields(Base{b: &fakeBackend{ + names: []string{ + "blk.0.a.weight", + "blk.0.b_weight", + "blk.0.b_bias", + "blk.0.weight_c", + "blk.0.x_weight_y", + "blk.1.a.weight", + "blk.1.b_weight", + "blk.1.b_bias", + "blk.1.weight_c", + "blk.1.x_weight_y", + }, + }}, v.Elem())) + + if diff := cmp.Diff(fakeModel{ + Blocks: []fakeBlock{ + { + A: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}}, + B: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}}, + C: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}}, + XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}}, + }, + { + A: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}}, + B: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}}, + C: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}}, + XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}}, + }, + }, + }, m); diff != "" { + t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) } } -type notTextProcessorModel struct{} +func TestModelForArch(t *testing.T) { + type fakeModel struct { + Model + } -func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) { - panic("unimplemented") -} + type fakeEmbeddingModel struct { + Model + } -func (notTextProcessorModel) Backend() ml.Backend { - panic("unimplemented") -} + models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil } + models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil } -func (notTextProcessorModel) Config() config { - panic("unimplemented") + cases := []struct { + name string + config fs.Config + want any + err error + }{ + { + name: "model", + config: fsggml.KV{ + "general.architecture": "model", + }, + want: fakeModel{}, + }, + { + name: "embedding", + config: fsggml.KV{ + "general.architecture": "model", + "model.pooling_type": uint32(1), + }, + want: fakeEmbeddingModel{}, + }, + { + name: "unsupported", + config: fsggml.KV{ + "general.architecture": "unsupported", + }, + err: ErrUnsupportedModel, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := modelForArch(tt.config) + if !errors.Is(err, tt.err) { + t.Fatal(err) + } + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff) + } + }) + } } diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go new file mode 100644 index 000000000..166c11e13 --- /dev/null +++ b/model/models/bert/embed.go @@ -0,0 +1,181 @@ +package bert + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/pooling" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Model struct { + model.Base + model.TextProcessor + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + TypeEmbedding *nn.Embedding `gguf:"token_types"` + PositionEmbedding *nn.Embedding `gguf:"position_embd"` + TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"` + + Layers []EncoderLayer `gguf:"blk"` + + Options +} + +// Forward implements model.Model. +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize)) + hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)))) + hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps) + + for _, layer := range m.Layers { + hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options) + } + + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) + if m.normalize { + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) + } + + return hiddenStates, nil +} + +type EncoderLayer struct { + *Attention + AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"` + + *MLP + MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"` +} + +func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + // Attention + residual := hiddenStates + hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + + // MLP + residual = hiddenStates + hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + + return hiddenStates +} + +type Attention struct { + Query *nn.Linear `gguf:"attn_q"` + QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"` + + Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"` + + Value *nn.Linear `gguf:"attn_v"` + + Output *nn.Linear `gguf:"attn_output"` +} + +func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + + query := a.Query.Forward(ctx, hiddenStates) + if a.QueryNorm != nil { + query = a.QueryNorm.Forward(ctx, query, opts.eps) + } + query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) + + key := a.Key.Forward(ctx, hiddenStates) + if a.KeyNorm != nil { + key = a.KeyNorm.Forward(ctx, key, opts.eps) + } + key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize) + + value := a.Value.Forward(ctx, hiddenStates) + value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize) + + attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) + return a.Output.Forward(ctx, attention) +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx)) +} + +type Options struct { + hiddenSize, + numHeads, + numKVHeads, + keyLength, + valueLength int + poolingType pooling.Type + eps float32 + normalize bool +} + +func (o Options) headDim() int { + return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) +} + +func New(c fs.Config) (model.Model, error) { + var processor model.TextProcessor + switch c.String("tokenizer.ggml.model", "bert") { + case "bert": + processor = model.NewWordPiece( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{ + int32(cmp.Or( + c.Uint("tokenizer.ggml.cls_token_id"), + c.Uint("tokenizer.ggml.bos_token_id"), + )), + }, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true), + EOS: []int32{ + int32(cmp.Or( + c.Uint("tokenizer.ggml.separator_token_id"), + //nolint:misspell + // NOTE: "seperator_token_id" is a typo in model metadata but we need to + // support it for compatibility. + c.Uint("tokenizer.ggml.seperator_token_id"), + c.Uint("tokenizer.ggml.eos_token_id"), + )), + }, + }, + ) + default: + return nil, model.ErrUnsupportedTokenizer + } + + return &Model{ + TextProcessor: processor, + Layers: make([]EncoderLayer, c.Uint("block_count")), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + eps: c.Float("attention.layer_norm_epsilon"), + poolingType: pooling.Type(c.Uint("pooling_type")), + normalize: c.Bool("normalize_embeddings", true), + }, + }, nil +} + +func init() { + model.Register("bert", New) + model.Register("bert_embed", New) +} diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go new file mode 100644 index 000000000..7b88711ba --- /dev/null +++ b/model/models/deepseek2/model.go @@ -0,0 +1,324 @@ +package deepseek2 + +// uses deepseek 2 architecture but written based on deepseek 3 model + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + numExpertsUsed int + numExperts int + normTopKProb bool + routedScalingFactor float32 + + kvLoraRank, + qkNopeHeadDim, + qkRopeHeadDim, + kqNopeHeadDim, + qkHeadDim int + qLoraRank int + vHeadDim int + + hiddenSize, + numHeads, + numKVHeads, + keyLength, + valueLength, + originalContextLength int + + eps, + ropeBase, + ropeScale float32 + kqScale float64 +} + +func (o Options) RoPEOptions() []func(*rope.Options) { + attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale)))) + return []func(*rope.Options){ + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithExtrapolationFactor(1.), + rope.WithAttentionFactor(attnFactor), + } +} + +type Attention struct { + Q *nn.Linear `gguf:"attn_q"` + + QA *nn.Linear `gguf:"attn_q_a"` + QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"` + QB *nn.Linear `gguf:"attn_q_b"` + + KVA *nn.Linear `gguf:"attn_kv_a_mqa"` + KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"` + KVB *nn.Linear `gguf:"attn_kv_b"` + + Output *nn.Linear `gguf:"attn_out,alt:attn_output"` +} + +func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + seqLength := hiddenStates.Dim(1) + + var query ml.Tensor + if opts.qLoraRank == 0 { // nil { + query = attn.Q.Forward(ctx, hiddenStates) + } else { + query = attn.QA.Forward(ctx, hiddenStates) + query = attn.QANorm.Forward(ctx, query, opts.eps) + query = attn.QB.Forward(ctx, query) + } + + query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength) + + qPass := query.View(ctx, 0, + opts.qkNopeHeadDim, query.Stride(1), + query.Dim(1), query.Stride(2), + query.Dim(2)) + + qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0), + opts.qkRopeHeadDim, query.Stride(1), + query.Dim(1), query.Stride(2), + query.Dim(2)) + + compressedKV := attn.KVA.Forward(ctx, hiddenStates) + + kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1)) + kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0), + opts.qkRopeHeadDim, compressedKV.Stride(1), + 1, compressedKV.Stride(1), + compressedKV.Dim(1)) + + kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps) + kPass = attn.KVB.Forward(ctx, kPass) + + kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength) + kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2)) + value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0), + opts.vHeadDim, kv.Stride(1), + kv.Dim(1), kv.Stride(2), + kv.Dim(2)).Contiguous(ctx) + + qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) + kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) + + kRot = kRot.Repeat(ctx, 1, qPass.Dim(1)) + + query = qRot.Concat(ctx, qPass, 0) + key := kRot.Concat(ctx, kPass, 0) + + attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache) + attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength) + return attn.Output.Forward(ctx, attention) +} + +type MLP interface { + Forward(ml.Context, ml.Tensor, *Options) ml.Tensor +} + +type sparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.Linear `gguf:"ffn_gate_exps"` + Up *nn.Linear `gguf:"ffn_up_exps"` + Down *nn.Linear `gguf:"ffn_down_exps"` + SharedExpert *dense `gguf:",suf:_shexp"` + ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"` +} + +func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor { + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + + upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices) + hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices) + hiddenStates = hiddenStates.SILU(ctx, upStates) + + experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices) + experts = experts.Mul(ctx, topKWeights) + nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + return nextStates +} + +func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor { + scores = scores.Add(ctx, moe.ExpProbsBias) + topKIndices := scores.TopK(ctx, opts.numExpertsUsed) + return topKIndices +} + +func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + residuals := hiddenStates + + routerLogits := moe.Router.Forward(ctx, hiddenStates) + scores := routerLogits.Sigmoid(ctx) + topKIndices := moe.topKIndices(ctx, scores, opts) + topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices) + + if opts.normTopKProb { + topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1)) + topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx)) + topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1)) + } + + topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor)) + hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts) + sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts) + + hiddenStates = hiddenStates.Add(ctx, sharedExpertResult) + return hiddenStates +} + +type dense struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + Attention *Attention + + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP MLP +} + +func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenStates + hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts) + + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + residual = hiddenStates + + hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + return hiddenStates +} + +type Model struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *Options +} + +func New(c fs.Config) (model.Model, error) { + layers := make([]Layer, c.Uint("block_count")) + + firstDenseLayerIndex := int(c.Uint("leading_dense_block_count")) + for i := range layers { + if i < firstDenseLayerIndex { + layers[i].MLP = &dense{} + } else { + layers[i].MLP = &sparse{} + } + } + + mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor")))) + kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length"))) + + m := Model{ + BytePairEncoding: model.NewBytePairEncoding( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + // Split regex into multiple parts (according to DeepSeek3's regex) + "\\p{N}{1,3}", + `[一-龥぀-ゟ゠-ヿ]+`, + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + ), + Layers: layers, + Options: &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.scaling.factor", 1), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("expert_weights_norm", true), + + qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal, + kvLoraRank: int(c.Uint("attention.kv_lora_rank")), + qkHeadDim: int(c.Uint("attention.key_length")), + vHeadDim: int(c.Uint("attention.value_length")), + qkRopeHeadDim: int(c.Uint("rope.dimension_count")), + qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")), + kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")), + + routedScalingFactor: c.Float("expert_weights_scale"), + originalContextLength: int(c.Uint("rope.scaling.original_context_length")), + + kqScale: kqScale, + }, + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + return &m, nil +} + +func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func init() { + model.Register("deepseek2", New) +} diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index e621d03ae..2b16dc62e 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -24,7 +24,7 @@ type Options struct { type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` @@ -40,7 +40,7 @@ const ( func New(c fs.Config) (model.Model, error) { m := Model{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), @@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) { attnValLen: int(c.Uint("attention.value_length")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base", 10000.0), - ropeScale: c.Float("rope.freq_scale", 1.0), + ropeScale: c.Float("rope.scaling.factor", 1.0), attnLogitSoftcap: c.Float("attn_logit_softcapping"), finalLogitSoftcap: c.Float("final_logit_softcapping"), }, @@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil } type MLP struct { @@ -138,7 +138,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) @@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var lastLayerOutputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + lastLayerOutputs = batch.Outputs } hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go new file mode 100644 index 000000000..525547767 --- /dev/null +++ b/model/models/gemma3/embed.go @@ -0,0 +1,62 @@ +package gemma3 + +import ( + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/pooling" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type embedModel struct { + model.Base + model.SentencePiece + + *TextModel + poolingType pooling.Type + + Dense [2]*nn.Linear `gguf:"dense"` +} + +func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) + for _, dense := range m.Dense { + hiddenStates = dense.Forward(ctx, hiddenStates) + } + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) + return hiddenStates, nil +} + +func newEmbedModel(c fs.Config) (model.Model, error) { + m := &embedModel{ + SentencePiece: model.NewSentencePiece( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{ + int32(c.Uint("tokenizer.ggml.eos_token_id")), + int32(c.Uint("tokenizer.ggml.eot_token_id", 106)), + }, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + ), + TextModel: newTextModel(c), + poolingType: pooling.Type(c.Uint("pooling_type", 0)), + } + + m.Cache = kvcache.NewWrapperCache( + kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift), + kvcache.NewCausalCache(m.Shift), + ) + + return m, nil +} diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 53bf82758..27da889e4 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -16,9 +16,9 @@ import ( type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece - *VisionModel `gguf:"v,vision"` + *VisionModel `gguf:"v"` *TextModel *MultiModalProjector `gguf:"mm"` @@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i func New(c fs.Config) (model.Model, error) { m := Model{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), @@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return []input.Multimodal{{Tensor: visionOutputs}}, nil } -func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { - var result []input.Input +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + var result []*input.Input for _, inp := range inputs { if len(inp.Multimodal) == 0 { @@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { inputMultimodal := inp.Multimodal[0].Tensor result = append(result, - input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" - input.Input{Token: 255999}, // """ - input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder + &input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" + &input.Input{Token: 255999}, // """ + &input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder ) // add image token placeholders - result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) + result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) result = append(result, - input.Input{Token: 256000}, // - input.Input{Token: 108}, // "\n\n" + &input.Input{Token: 256000}, // + &input.Input{Token: 108}, // "\n\n" ) } } @@ -141,12 +141,11 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil + hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) + return m.Output.Forward(ctx, hiddenStates), nil } func init() { model.Register("gemma3", New) + model.Register("gemma3_embed", newEmbedModel) } diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 70d7797e9..631baeccd 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel { eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), - ropeScale: c.Float("rope.freq_scale", 1.0), + ropeScale: 1, + // NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights + // (8 instead of 1) + // ropeScale: c.Float("rope.scaling.factor", 1.0), }, } @@ -84,7 +87,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = sa.QueryNorm.Forward(ctx, q, opts.eps) - q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -95,7 +98,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = sa.KeyNorm.Forward(ctx, k, opts.eps) - k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.TextConfig.ropeGlobalBase } - return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil } type TextMLP struct { @@ -123,7 +126,7 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -159,8 +162,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, return hiddenState.Add(ctx, residual) } -func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor { - hiddenState := m.TokenEmbedding.Forward(ctx, inputs) +func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor { + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + + hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) // set image embeddings @@ -191,12 +196,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor var lastLayerOutputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + lastLayerOutputs = batch.Outputs } hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) - return m.Output.Forward(ctx, hiddenState) + return hiddenState } diff --git a/model/models/gemma3n/model.go b/model/models/gemma3n/model.go index 6e83a9724..e59e3193f 100644 --- a/model/models/gemma3n/model.go +++ b/model/models/gemma3n/model.go @@ -10,7 +10,7 @@ import ( type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece *TextModel } @@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func New(c fs.Config) (model.Model, error) { m := Model{ TextModel: newTextModel(c), - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index b75a2abb3..d0e9a026a 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx) hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx) - hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))) + hiddenStates = hiddenStates.Rows(ctx, batch.Outputs) hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) return m.Output.Forward(ctx, hiddenStates), nil @@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.ropeBaseLocal } - return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil } type TextScaledWordEmbedding struct { @@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position } active = d.PerLayerInputGate.Forward(ctx, active) - active = active.GELU(ctx) - active = active.Mul(ctx, perLayerInput) + active = active.GELU(ctx, perLayerInput) active = d.PerLayerProjection.Forward(ctx, active) active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps) @@ -257,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten query := attn.Query.Forward(ctx, hiddenStates) query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) query = attn.QueryNorm.Forward(ctx, query, opts.eps) - query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) var key, value ml.Tensor if !sharedKV { key = attn.Key.Forward(ctx, hiddenStates) key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) key = attn.KeyNorm.Forward(ctx, key, opts.eps) - key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) value = attn.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) @@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx) } - hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates) + hiddenStates = hiddenStates.GELU(ctx, upStates) hiddenStates = mlp.Down.Forward(ctx, hiddenStates) return hiddenStates } @@ -350,7 +349,7 @@ func newTextModel(c fs.Config) *TextModel { eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), ropeBase: c.Float("rope.freq_base", 1_000_000), ropeBaseLocal: c.Float("rope.freq_base_local", 10_000), - ropeScale: c.Float("rope.freq_scale", 1.0), + ropeScale: c.Float("rope.scaling.factor", 1.0), slidingWindowPattern: c.Bools("attention.sliding_window_pattern"), activationSparsityScale: c.Floats("activation_sparsity_scale"), diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index 3ef078095..6a3270651 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err } var outputs ml.Tensor - if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if i == len(m.TransformerBlocks)-1 { + outputs = batch.Outputs } hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options) @@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts * up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts) } - hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7) + hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7) experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts = experts.Mul(ctx, routingWeights) @@ -227,17 +227,6 @@ func New(c fs.Config) (model.Model, error) { m := Transformer{ TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")), BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", - strings.Join([]string{ - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, - `\p{N}{1,3}`, - ` ?[^\s\p{L}\p{N}]+[\r\n/]*`, - `\s*[\r\n]+`, - `\s+(?!\S)`, - `\s+`, - }, "|"), - ), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -250,6 +239,15 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + strings.Join([]string{ + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `\p{N}{1,3}`, + ` ?[^\s\p{L}\p{N}]+[\r\n/]*`, + `\s*[\r\n]+`, + `\s+(?!\S)`, + `\s+`, + }, "|"), ), Options: Options{ hiddenSize: int(c.Uint("embedding_length")), diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 77d8f36d3..c03f04a0d 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -2,7 +2,6 @@ package llama import ( "cmp" - "fmt" "math" "github.com/ollama/ollama/fs" @@ -23,51 +22,80 @@ type Options struct { type Model struct { model.Base - model.BytePairEncoding + model.TextProcessor TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` OutputNorm *nn.RMSNorm `gguf:"output_norm"` Output *nn.Linear `gguf:"output,alt:token_embd"` - *Options + Options } func New(c fs.Config) (model.Model, error) { - // This model currently only supports the gpt2 tokenizer - if c.String("tokenizer.ggml.model") == "llama" { - return nil, fmt.Errorf("unsupported tokenizer: llama") + if c.Uint("expert_count") > 0 { + // TODO: support mixtures of experts + return nil, model.ErrUnsupportedModel } - // Best effort detection of library/deepseek-coder model(s) which are incompatible - if c.String("general.name") == "deepseek-ai" { - return nil, fmt.Errorf("unsupported model: %s", c.String("general.name")) - } - m := Model{ - BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), - &model.Vocabulary{ - Values: c.Strings("tokenizer.ggml.tokens"), - Types: c.Ints("tokenizer.ggml.token_type"), - Merges: c.Strings("tokenizer.ggml.merges"), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, - AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOS: append( - []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, - c.Ints("tokenizer.ggml.eos_token_ids")..., - ), - }, + + var processor model.TextProcessor + vocabulary := model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., ), - Layers: make([]Layer, c.Uint("block_count")), - Options: &Options{ + } + switch c.String("tokenizer.ggml.model") { + case "gpt2": + var pretokenizers []string + switch c.String("tokenizer.ggml.pre") { + case "default": + // no-op use the default bpe pretokenizer + case "qwen2": + pretokenizers = []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + case "refact": + pretokenizers = []string{ + `\p{N}`, + `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`, + } + case "tekken": + pretokenizers = []string{ + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + default: + // use a llama-style pretokenizer + pretokenizers = []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + } + processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...) + case "llama": + processor = model.NewSentencePiece(&vocabulary) + default: + return nil, model.ErrUnsupportedTokenizer + } + + m := Model{ + TextProcessor: processor, + Layers: make([]Layer, c.Uint("block_count")), + Options: Options{ hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), headDim: int(c.Uint("attention.key_length")), ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeBase: c.Float("rope.freq_base", 1e5), + ropeScale: c.Float("rope.scaling.factor", 1), }, } @@ -98,8 +126,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) @@ -108,7 +136,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) - return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil } type MLP struct { @@ -118,7 +146,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -160,10 +188,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + outputs = batch.Outputs } - hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) + hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index 8084760b0..e80fbaed6 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -18,7 +18,7 @@ type Model struct { model.BytePairEncoding ImageProcessor - *VisionModel `gguf:"v,vision"` + *VisionModel `gguf:"v"` *Projector `gguf:"mm"` *TextModel } @@ -34,8 +34,6 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor { func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -48,6 +46,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), @@ -134,16 +133,16 @@ type separator struct { y bool } -func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { - var result []input.Input +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + var result []*input.Input for _, inp := range inputs { if len(inp.Multimodal) == 0 { result = append(result, inp) continue } - var imageInputs []input.Input - imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|> + var imageInputs []*input.Input + imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|> for i, mm := range inp.Multimodal { patchesPerChunk := mm.Tensor.Dim(1) @@ -151,20 +150,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { if i < len(inp.Multimodal)-1 { separator := mm.Data.(*separator) - imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> - imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) + imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> + imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...) if separator.x { - imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|> + imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|> } if separator.y { - imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|> + imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|> } } else { - imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|> - imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> - imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) - imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|> + imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|> + imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> + imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...) + imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|> } } @@ -176,9 +175,7 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index 045ab403f..e056391f5 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) if useRope { - query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) } if opts.useQKNorm { @@ -58,14 +58,14 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } type TextExperts struct { - Gate *nn.Linear `gguf:"ffn_gate_exps"` - Up *nn.Linear `gguf:"ffn_up_exps"` - Down *nn.Linear `gguf:"ffn_down_exps"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` } func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { @@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) hiddenStates = hiddenStates.Mul(ctx, scores) - upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts) - gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts) - downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) + upStates := e.Up.Forward(ctx, hiddenStates, experts) + gateStates := e.Gate.Forward(ctx, hiddenStates, experts) + downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) for i := 1; i < opts.numExpertsUsed; i++ { @@ -88,22 +88,10 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens return nextStates } -// TextSharedExpert is TextMLP with different tensor names -type TextSharedExpert struct { - Gate *nn.Linear `gguf:"ffn_gate_shexp"` - Up *nn.Linear `gguf:"ffn_up_shexp"` - Down *nn.Linear `gguf:"ffn_down_shexp"` -} - -func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) - return mlp.Down.Forward(ctx, hiddenStates) -} - type TextMOE struct { Router *nn.Linear `gguf:"ffn_gate_inp"` Experts *TextExperts - SharedExpert *TextSharedExpert + SharedExpert *TextMLP `gguf:",suf:_shexp"` } func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { @@ -196,7 +184,7 @@ func newTextModel(c fs.Config) *TextModel { numExpertsUsed: int(c.Uint("expert_used_count")), ropeDim: int(c.Uint("rope.dimension_count")), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), eps: c.Float("attention.layer_norm_rms_epsilon"), interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)), noRopeInterval: int(c.Uint("no_rope_interval", 4)), @@ -248,5 +236,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil } diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 9d662fc11..5c46615e9 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -18,7 +18,7 @@ type Model struct { model.BytePairEncoding *TextModel - *VisionModel `gguf:"v,vision"` + *VisionModel `gguf:"v"` *MultiModalProjector `gguf:"mm"` ImageProcessor @@ -33,7 +33,6 @@ var _ model.TextProcessor = (*Model)(nil) func New(c fs.Config) (model.Model, error) { m := &Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), TextModel: newTextModel(c), VisionModel: newVisionModel(c), @@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input // [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END] // Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings // that can be processed together. -func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { - var result []input.Input +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + var result []*input.Input for _, inp := range inputs { if len(inp.Multimodal) == 0 { result = append(result, inp) } else { for i, row := range inp.Multimodal { // [IMG] - result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)}) - result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...) + result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)}) + result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...) if i == len(inp.Multimodal)-1 { // [IMG_END] - result = append(result, input.Input{Token: 13}) + result = append(result, &input.Input{Token: 13}) } else { // [IMG_BREAK] - result = append(result, input.Input{Token: 12}) + result = append(result, &input.Input{Token: 12}) } } } @@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 19c36f9fe..d2e2eac6c 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil } type MLP struct { @@ -65,7 +65,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel { ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), }, } } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 65bdcff2a..3bfb8c90a 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -51,7 +51,7 @@ type VisionMLP struct { } func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 45cb3e02c..769743694 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -17,7 +17,7 @@ type Model struct { model.Base model.BytePairEncoding - *VisionModel `gguf:"v,vision"` + *VisionModel `gguf:"v"` *TextModel Projector *nn.Linear `gguf:"mm.0"` @@ -33,7 +33,6 @@ const ( func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), @@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return []input.Multimodal{{Tensor: projectedOutputs}}, nil } -func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { for i := range inputs { if inputs[i].Multimodal != nil { inputs[i].Token = 128256 // <|image|> @@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) // TODO: attention mask, cross attention mask - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil } func init() { diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 47a518ced..65f0a8278 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { // This will only get called for layers in the cache, which are just the self attention layers if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil } return key, nil @@ -58,7 +58,7 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel { ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), crossAttentionLayers: c.Ints("attention.cross_attention_layers"), }, } diff --git a/model/models/models.go b/model/models/models.go index c880a4720..0cda615af 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -1,6 +1,8 @@ package models import ( + _ "github.com/ollama/ollama/model/models/bert" + _ "github.com/ollama/ollama/model/models/deepseek2" _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3n" diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 3c662f068..2e2347102 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -43,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, value := attn.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) @@ -59,7 +59,7 @@ type MLP struct { } func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } @@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + outputs = batch.Outputs } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) @@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) - return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil } func New(c fs.Config) (model.Model, error) { @@ -139,7 +139,6 @@ func New(c fs.Config) (model.Model, error) { m := Model{ Layers: make([]DecoderLayer, c.Uint("block_count")), BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -152,6 +151,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Options: Options{ hiddenSize: int(c.Uint("embedding_length")), @@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) { headDim: int(c.Uint("attention.key_length")), ropeDim: int(c.Uint("rope.dimension_count")), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), eps: c.Float("attention.layer_norm_rms_epsilon"), }, } diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index ee38cad92..6898e38ca 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -18,7 +18,7 @@ type Model struct { model.BytePairEncoding *TextModel - *VisionModel `gguf:"v,vision"` + *VisionModel `gguf:"v"` ImageProcessor } @@ -29,7 +29,6 @@ var _ model.MultimodalProcessor = (*Model)(nil) func New(c fs.Config) (model.Model, error) { m := &Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -42,6 +41,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), TextModel: NewTextModel(c), VisionModel: newVisionModel(c), @@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input } // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass -func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { - var result []input.Input +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + var result []*input.Input var ( imageToken int32 = 151655 @@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { return nil, fmt.Errorf("failed to encode image prompt: %w", err) } for i := range pre { - result = append(result, input.Input{Token: pre[i]}) + result = append(result, &input.Input{Token: pre[i]}) } patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1) // First add the vision start token - result = append(result, input.Input{Token: visionStartToken}) + result = append(result, &input.Input{Token: visionStartToken}) // Add the image token with the multimodal tensor data at the first position - result = append(result, input.Input{ + result = append(result, &input.Input{ Token: imageToken, Multimodal: inp.Multimodal, MultimodalHash: inp.MultimodalHash, @@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { }) // Add the placeholder tokens for the remaining positions (tokensPerGrid-1) - result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...) + result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...) - result = append(result, input.Input{Token: visionEndToken}) + result = append(result, &input.Input{Token: visionEndToken}) } } @@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache) } func init() { diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index 4b6bc1666..e6c6e6c19 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel { originalContextLength: int(c.Uint("context_length", 128000)), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), }, } @@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten // Shift applies rotary position embeddings to the key tensor for causal attention caching func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil } // MLP implements the feed-forward network component with SwiGLU activation @@ -90,7 +90,7 @@ type MLP struct { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { // Apply SwiGLU activation gating - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) // Project back to hidden dimension return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 4d7afaa14..3dd60e3ba 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -100,8 +100,7 @@ type VisionMLP struct { func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { // Using activation as specified in config (likely GELU or SiLU/Swish) gateOutput := mlp.Gate.Forward(ctx, hiddenStates) - upOutput := mlp.Up.Forward(ctx, hiddenStates) - hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput) + hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/qwen3/embed.go b/model/models/qwen3/embed.go new file mode 100644 index 000000000..c03888d45 --- /dev/null +++ b/model/models/qwen3/embed.go @@ -0,0 +1,73 @@ +package qwen3 + +import ( + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn/pooling" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type embedModel struct { + model.Base + model.BytePairEncoding + + *Model + poolingType pooling.Type +} + +func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates, err := m.forward(ctx, batch) + if err != nil { + return nil, err + } + + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) + return hiddenStates, nil +} + +func newEmbed(c fs.Config) (model.Model, error) { + layers := make([]Layer, c.Uint("block_count")) + for i := range layers { + layers[i].MLP = &dense{} + } + m := embedModel{ + BytePairEncoding: model.NewBytePairEncoding( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + ), + Model: &Model{ + Layers: layers, + Options: &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("norm_top_k_prob", true), + }, + }, + poolingType: pooling.Type(c.Uint("pooling_type")), + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + return &m, nil +} diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 7a83e0d04..cc58e4a28 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -30,10 +30,10 @@ func (o Options) headDim() int { } type Attention struct { - QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` Query *nn.Linear `gguf:"attn_q"` - KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` Value *nn.Linear `gguf:"attn_v"` Output *nn.Linear `gguf:"attn_output"` } @@ -52,8 +52,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, query = sa.QueryNorm.Forward(ctx, query, opts.eps) key = sa.KeyNorm.Forward(ctx, key, opts.eps) - query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) @@ -65,10 +65,10 @@ type MLP interface { } type sparse struct { - Router *nn.Linear `gguf:"ffn_gate_inp"` - Gate *nn.Linear `gguf:"ffn_gate_exps"` - Up *nn.Linear `gguf:"ffn_up_exps"` - Down *nn.Linear `gguf:"ffn_down_exps"` + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` } func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { @@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) - upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts)) - hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts) - hiddenStates = hiddenStates.SILU(ctx) - hiddenStates = hiddenStates.Mul(ctx, upStates) - - experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts) + experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts = experts.Mul(ctx, routingWeights) nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) @@ -111,7 +107,8 @@ type dense struct { } func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates). + SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } @@ -154,29 +151,39 @@ type Model struct { *Options } -// Forward implements model.Model. func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates, err := m.forward(ctx, batch) + if err != nil { + return nil, err + } + + return m.Output.Forward(ctx, hiddenStates), nil +} + +// Forward implements model.Model. +func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) for i, layer := range m.Layers { - m.Cache.SetLayer(i) + if m.Cache != nil { + m.Cache.SetLayer(i) + } var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + outputs = batch.Outputs } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) } - hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) - return m.Output.Forward(ctx, hiddenStates), nil + return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil } var _ model.Model = (*Model)(nil) @@ -193,7 +200,6 @@ func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -206,6 +212,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Layers: layers, Options: &Options{ @@ -216,7 +223,7 @@ func New(c fs.Config) (model.Model, error) { valueLength: int(c.Uint("attention.value_length")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), numExperts: int(c.Uint("expert_count")), numExpertsUsed: int(c.Uint("expert_used_count")), normTopKProb: c.Bool("norm_top_k_prob", true), @@ -230,4 +237,5 @@ func New(c fs.Config) (model.Model, error) { func init() { model.Register("qwen3", New) model.Register("qwen3moe", New) + model.Register("qwen3_embed", newEmbed) } diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go new file mode 100644 index 000000000..a1d4e8127 --- /dev/null +++ b/model/parsers/parsers.go @@ -0,0 +1,49 @@ +package parsers + +import ( + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/harmony" +) + +type Parser interface { + // Init initializes the parser with tools and optional last message for chat prefill + // Returns processed tools if the parser needs to modify them (e.g., harmony renames them) + Init(tools []api.Tool, lastMessage *api.Message) []api.Tool + // Add processes streamed content and returns parsed content, thinking, and tool calls + // The done flag indicates if this is the last chunk (used for draining accumulators) + Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) + HasToolSupport() bool + HasThinkingSupport() bool +} + +func ParserForName(name string) Parser { + switch name { + case "qwen3-coder": + parser := &Qwen3CoderParser{} + return parser + case "passthrough": + return &PassthroughParser{} + case "harmony": + return harmony.NewHarmonyMessageHandler() + default: + return nil + } +} + +type PassthroughParser struct{} + +func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + return tools // passthrough doesn't modify tools +} + +func (p *PassthroughParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + return s, "", nil, nil +} + +func (p *PassthroughParser) HasToolSupport() bool { + return false +} + +func (p *PassthroughParser) HasThinkingSupport() bool { + return false +} diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go new file mode 100644 index 000000000..f44d7c8ef --- /dev/null +++ b/model/parsers/qwen3coder.go @@ -0,0 +1,463 @@ +package parsers + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "log/slog" + "math" + "regexp" + "strconv" + "strings" + "unicode" + "unicode/utf8" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/logutil" +) + +type qwenParserState int + +const ( + toolOpenTag = "" + toolCloseTag = "" +) + +const ( + qwenParserState_LookingForToolStart qwenParserState = iota + qwenParserState_CollectingToolContent +) + +type Qwen3CoderParser struct { + state qwenParserState + acc strings.Builder + tools []api.Tool +} + +func (p *Qwen3CoderParser) HasToolSupport() bool { + return true +} + +func (p *Qwen3CoderParser) HasThinkingSupport() bool { + return false +} + +func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + p.tools = tools + return tools // Qwen doesn't modify tools +} + +func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.acc.WriteString(s) + + events := p.parseEvents() + + var toolCalls []api.ToolCall + var sb strings.Builder + for _, event := range events { + switch event := event.(type) { + case qwenEventRawToolCall: + toolCall, err := parseToolCall(event, p.tools) + if err != nil { + slog.Warn("qwen tool call parsing failed", "error", err) + return "", "", nil, err + } + toolCalls = append(toolCalls, toolCall) + case qwenEventContent: + // TODO(drifkin): if the same turn contains multiple interleaved content + // events, we naively append them together here. See the note below about + // `qwenEvent`s for more details + sb.WriteString(event.content) + } + } + + return sb.String(), "", toolCalls, nil +} + +func (p *Qwen3CoderParser) parseEvents() []qwenEvent { + var all []qwenEvent + + keepLooping := true + for keepLooping { + var events []qwenEvent + events, keepLooping = eat(p) + if len(events) > 0 { + all = append(all, events...) + } + } + + if len(all) > 0 { + slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String()) + } + + return all +} + +// we use some internal event types in order to communicate between `Add` and +// `eat`. We do this to support interleaving content and parallel tool calls in +// the parser, even though qwen3-coder isn't supposed to do this. Our API +// doesn't currently support models outputting multiple messages in a turn, so +// we wouldn't be able to represent it yet, but there's no reason to prevent the +// parser from supporting it, especially for future models if they end up using +// a similar format. +type qwenEvent interface { + isQwenEvent() +} + +type qwenEventRawToolCall struct { + raw string +} + +type qwenEventContent struct { + content string +} + +func (qwenEventContent) isQwenEvent() {} +func (qwenEventRawToolCall) isQwenEvent() {} + +// eat consumes the parser's buffer, and returns a list of any unambiguous +// events from the current parser state. If the parser transitions to another +// state, it may have additional events to emit on the next call, which is what +// the second return value indicates +func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) { + var events []qwenEvent + + switch p.state { + case qwenParserState_LookingForToolStart: + if strings.Contains(p.acc.String(), toolOpenTag) { + // we found a full tool open tag, so we can emit the content before the + // tag, being sure to trim any trailing whitespace + split := strings.SplitN(p.acc.String(), toolOpenTag, 2) + before := split[0] + before = strings.TrimRightFunc(before, unicode.IsSpace) + if len(before) > 0 { + events = append(events, qwenEventContent{content: before}) + } + after := split[1] + p.acc.Reset() + p.acc.WriteString(after) + p.state = qwenParserState_CollectingToolContent + return events, true + } else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 { + // we found a partial tool open tag, so we can emit the unambiguous part, + // which is the (trailing-whitespace trimmed) content before the partial + // tool open tag + beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap] + trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen + unambiguous := p.acc.String()[:ambiguousStart] + ambiguous := p.acc.String()[ambiguousStart:] + p.acc.Reset() + p.acc.WriteString(ambiguous) + events = append(events, qwenEventContent{content: unambiguous}) + return events, false + } else { + // we found content that is entirely not a tool call. We should withhold + // any trailing whitespace in case this is the end of the content + whitespaceLen := trailingWhitespaceLen(p.acc.String()) + ambiguousStart := len(p.acc.String()) - whitespaceLen + unambiguous := p.acc.String()[:ambiguousStart] + ambiguous := p.acc.String()[ambiguousStart:] + p.acc.Reset() + p.acc.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, qwenEventContent{content: unambiguous}) + } + return events, false + } + case qwenParserState_CollectingToolContent: + if strings.Contains(p.acc.String(), toolCloseTag) { + split := strings.SplitN(p.acc.String(), toolCloseTag, 2) + before := split[0] + if len(before) == 0 { + slog.Warn("qwen tool call closing tag found but no content before it") + } + // remove any whitespace between the tool call and any content after it + after := strings.TrimLeftFunc(split[1], unicode.IsSpace) + p.acc.Reset() + p.acc.WriteString(after) + events = append(events, qwenEventRawToolCall{raw: before}) + p.state = qwenParserState_LookingForToolStart + return events, true + } else { + // note that we don't need to check the overlap here because we only plan + // on parsing the tool call once we see the full closing tag. We don't + // stream back the unparsed tool content, so there's no need to be eager + // here + return events, false + } + default: + panic("unreachable") + } +} + +// TODO(drifkin): move this to a shared location +// longest overlap between suffix of s and prefix of delim +func overlap(s, delim string) int { + max := min(len(delim), len(s)) + for i := max; i > 0; i-- { + if strings.HasSuffix(s, delim[:i]) { + return i + } + } + return 0 +} + +func trailingWhitespaceLen(s string) int { + remaining := s + total := 0 + for len(remaining) > 0 { + r, size := utf8.DecodeLastRuneInString(remaining) + // if it's an invalid utf8 rune, assume it isn't whitespace + if r == utf8.RuneError && size == 1 { + break + } + if !unicode.IsSpace(r) { + break + } + total += size + remaining = remaining[:len(remaining)-size] + } + return total +} + +type XMLFunctionCall struct { + XMLName xml.Name `xml:"function"` + Name string `xml:"name,attr"` + Parameters []XMLParameter `xml:"parameter"` +} + +type XMLParameter struct { + Name string `xml:"name,attr"` + Value string `xml:",chardata"` +} + +// parseToolCall parses a raw tool call string into an api.ToolCall. +// The raw string follows an xml-like format, here's an example: +// +// +// +// San Francisco +// +// +// celsius +// +// +func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) { + toolCall := api.ToolCall{} + + xmlString := transformToXML(raw.raw) + + var functionCall XMLFunctionCall + err := xml.Unmarshal([]byte(xmlString), &functionCall) + if err != nil { + return api.ToolCall{}, err + } + + toolCall.Function = api.ToolCallFunction{ + Name: functionCall.Name, + } + + // Find the matching tool to get parameter types + var matchedTool *api.Tool + for i := range tools { + if tools[i].Function.Name == functionCall.Name { + matchedTool = &tools[i] + break + } + } + + toolCall.Function.Arguments = make(api.ToolCallFunctionArguments) + for _, parameter := range functionCall.Parameters { + // Look up the parameter type if we found the tool + var paramType api.PropertyType + if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil { + if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok { + paramType = prop.Type + } + } + + toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType) + } + + return toolCall, nil +} + +// parseValue converts a raw string value to the appropriate type based on the parameter type specification. +// +// For union types (multiple types in PropertyType, which we support but doesn't +// seem as though the reference parser does type coercion with those types in +// mind) we use a type precedence approach: +// 1. null - checked first regardless of declared types (matches reference implementation) +// 2. boolean - only "true"/"false" are valid booleans +// 3. integer - must parse as a whole number +// 4. number - must parse as numeric (returns int if no decimal part) +// 5. array - must parse as valid JSON array +// 6. object - must parse as valid JSON object +// 7. string - always succeeds (least specific type) +// +// This precedence ensures we return the most specific type that successfully parses, +// following the principle of least surprise. For example, with PropertyType{"string", "number"}, +// "123" becomes 123 (number), while "hello" becomes "hello" (string). +func parseValue(raw string, paramType api.PropertyType) any { + // first remove a single leading newlines, and a single trailing newline (if + // they exist). This follows the reference implementation + raw = strings.TrimPrefix(raw, "\n") + raw = strings.TrimSuffix(raw, "\n") + + // Check for null first (case-insensitive) - this takes precedence over any type + if strings.ToLower(raw) == "null" { + return nil + } + + // If no type is specified, default to string + if len(paramType) == 0 { + return raw + } + + // Check if any of the specified types match, using type precedence + // Order: boolean -> integer -> number -> array -> object -> string + typeSet := make(map[string]bool) + for _, t := range paramType { + typeSet[t] = true + } + + // Try boolean first (most restrictive) + if typeSet["boolean"] { + lower := strings.ToLower(raw) + switch lower { + case "true": + return true + case "false": + return false + } + // If not a valid boolean but boolean is the only type, return false (matching reference) + if len(paramType) == 1 { + return false + } + // Otherwise try other types + } + + // Try integer + if typeSet["integer"] { + if i, err := strconv.ParseInt(raw, 10, 64); err == nil { + // Return as int if it fits in int32, otherwise int64 + if i >= math.MinInt32 && i <= math.MaxInt32 { + return int(i) + } + return i + } + // If integer is the only type and parsing failed, fall back to string + if len(paramType) == 1 { + return raw + } + } + + // Try number (float) + if typeSet["number"] { + if f, err := strconv.ParseFloat(raw, 64); err == nil { + // If the number has no decimal part, return as int (matching reference) + if f == math.Trunc(f) { + i := int64(f) + if i >= math.MinInt32 && i <= math.MaxInt32 { + return int(i) + } + return i + } + return f + } + // If number is the only type and parsing failed, fall back to string + if len(paramType) == 1 { + return raw + } + } + + // Try array + if typeSet["array"] { + var arr []any + if err := json.Unmarshal([]byte(raw), &arr); err == nil { + return arr + } + // If array is the only type and parsing failed, fall back to string + if len(paramType) == 1 { + return raw + } + } + + // Try object + if typeSet["object"] { + var obj map[string]any + if err := json.Unmarshal([]byte(raw), &obj); err == nil { + return obj + } + // If object is the only type and parsing failed, fall back to string + if len(paramType) == 1 { + return raw + } + } + + // String always succeeds (or if "string" is in the type set) + if typeSet["string"] { + return raw + } + + // If we get here, none of the types matched and string wasn't an option + // We return string as a fallback. The reference implementation will attempt + // to parse the value as a python literal, but we purposefully don't support + // that + return raw +} + +var ( + qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`) + qwenXMLTagRegex = regexp.MustCompile(``) +) + +// transformToXML transforms a raw qwen tool call with xml-like tags into valid +// xml so that it can be parsed by any xml parser +func transformToXML(raw string) string { + // take the form `` and transform it to ``, taking + // care to properly escape the string that becomes the attribute value + transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string { + groups := qwenTagRegex.FindStringSubmatch(match) + tag := groups[1] + var escapedValue strings.Builder + xml.EscapeText(&escapedValue, []byte(groups[2])) + return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String()) + }) + + // Walk the resulting string, escaping any character data that sits between the + // xml tags we just emitted + var out strings.Builder + lastIdx := 0 + for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) { + if loc[0] > lastIdx { + escapeTextNode(&out, transformed[lastIdx:loc[0]]) + } + out.WriteString(transformed[loc[0]:loc[1]]) + lastIdx = loc[1] + } + if lastIdx < len(transformed) { + escapeTextNode(&out, transformed[lastIdx:]) + } + + return out.String() +} + +// escapeTextNode escapes XML character data without altering other characters +// like newlines or tabs (which is why we don't use xml.EscapeText for this) +func escapeTextNode(sb *strings.Builder, s string) { + for _, r := range s { + switch r { + case '&': + sb.WriteString("&") + case '<': + sb.WriteString("<") + case '>': + sb.WriteString(">") + default: + sb.WriteRune(r) + } + } +} diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go new file mode 100644 index 000000000..c77fe2d95 --- /dev/null +++ b/model/parsers/qwen3coder_test.go @@ -0,0 +1,1095 @@ +package parsers + +import ( + "reflect" + "testing" + + "github.com/ollama/ollama/api" +) + +// tool creates a test tool with the given name and properties +func tool(name string, props map[string]api.ToolProperty) api.Tool { + t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}} + t.Function.Parameters.Type = "object" + t.Function.Parameters.Properties = props + return t +} + +func TestQwenParserStreaming(t *testing.T) { + type step struct { + input string + wantEvents []qwenEvent + } + + cases := []struct { + desc string + steps []step + only bool + }{ + { + desc: "simple message streamed word by word", + steps: []step{ + { + input: "hi", + wantEvents: []qwenEvent{qwenEventContent{content: "hi"}}, + }, + { + input: " there", + wantEvents: []qwenEvent{qwenEventContent{content: " there"}}, + }, + }, + }, + { + desc: "content before tool call", + steps: []step{ + { + input: "hi there", + wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}}, + }, + }, + }, + { + desc: "multiple tool calls in one message", + steps: []step{ + { + input: "before1in tool callafter1in tool call 2after2", + wantEvents: []qwenEvent{ + qwenEventContent{content: "before1"}, + qwenEventRawToolCall{raw: "in tool call"}, + qwenEventContent{content: "after1"}, + qwenEventRawToolCall{raw: "in tool call 2"}, + qwenEventContent{content: "after2"}, + }, + }, + }, + }, + { + desc: "tool calls with split tags", + steps: []step{ + { + input: "beforein tool callaf", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "in tool call"}, + qwenEventContent{content: "af"}, + }, + }, + { + input: "ter", + wantEvents: []qwenEvent{ + qwenEventContent{content: "ter"}, + }, + }, + }, + }, + { + desc: "trailing whitespace between content and tool call", + steps: []step{ + { + input: "abc\ndef", + wantEvents: []qwenEvent{ + qwenEventContent{content: "abc"}, + qwenEventRawToolCall{raw: "def"}, + }, + }, + }, + }, + { + desc: "trailing whitespace between tool call and content", + steps: []step{ + { + input: "abc\ndef", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + qwenEventContent{content: "def"}, + }, + }, + }, + }, + { + desc: "empty content before tool call", + steps: []step{ + { + input: "\nabc", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + }, + }, + }, + }, + { + desc: "partial tool open tag fakeout", + steps: []step{ + { + input: "abc\ntestمرحبا", + wantEvents: []qwenEvent{ + qwenEventContent{content: "你好 🌍"}, + qwenEventRawToolCall{raw: "test"}, + qwenEventContent{content: "مرحبا"}, + }, + }, + }, + }, + { + desc: "arabic text handling", + steps: []step{ + { + input: "مرحبا بالعالم", + wantEvents: []qwenEvent{qwenEventContent{content: "مرحبا بالعالم"}}, + }, + }, + }, + { + desc: "emoji passthrough", + steps: []step{ + { + input: "✅", + wantEvents: []qwenEvent{qwenEventContent{content: "✅"}}, + }, + }, + }, + { + desc: "emoji after tool call", + steps: []step{ + { + input: "test完成 ✅", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "test"}, + qwenEventContent{content: "完成 ✅"}, + }, + }, + }, + }, + { + desc: "unicode streaming with whitespace handling", + steps: []step{ + { + input: "مرحبا", + wantEvents: []qwenEvent{ + qwenEventContent{content: "مرحبا"}, + }, + }, + { + input: " \n", + wantEvents: []qwenEvent{}, + }, + { + input: "世界", + wantEvents: []qwenEvent{ + qwenEventContent{content: " \n世界"}, + }, + }, + }, + }, + { + desc: "non-breaking space withheld across chunks", + steps: []step{ + { + input: "Hello\u00a0", + wantEvents: []qwenEvent{ + qwenEventContent{content: "Hello"}, + }, + }, + { + input: "world", + wantEvents: []qwenEvent{ + qwenEventContent{content: "\u00a0world"}, + }, + }, + }, + }, + { + desc: "ideographic space before partial tool", + steps: []step{ + { + input: "Hello\u3000abc", + wantEvents: []qwenEvent{}, + }, + { + input: "def", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + qwenEventContent{content: "def"}, + }, + }, + }, + }, + { + desc: "ideographic space before partial tool fakeout", + steps: []step{ + { + input: "Hello\u3000abc", + wantEvents: []qwenEvent{ + qwenEventContent{content: "\u3000abc"}, + }, + }, + }, + }, + { + desc: "unicode with partial tool tag", + steps: []step{ + { + input: "测试🎯 + +San Francisco + + +celsius + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_temperature", + Arguments: map[string]any{ + "location": "San Francisco", + "unit": "celsius", + }, + }, + }, + }, + { + name: "names with spaces", + tools: []api.Tool{}, + rawToolCall: ` + +San Francisco + + +celsius + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get current temperature", + Arguments: map[string]any{ + "location with spaces": "San Francisco", + "unit with spaces": "celsius", + }, + }, + }, + }, + // this mirrors the reference implementation's behavior, but unclear if it + // ever happens. If so, then we should probably remove them instead, this + // test is to just document the current behavior and test that we don't get + // xml errors + { + name: "names with quotes", + tools: []api.Tool{}, + rawToolCall: ` + +San Francisco + + +"celsius" + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "\"get current temperature\"", + Arguments: map[string]any{ + "\"location with spaces\"": "San Francisco", + "\"unit with spaces\"": "\"celsius\"", + }, + }, + }, + }, + { + name: "tool call with typed parameters", + tools: []api.Tool{ + tool("calculate", map[string]api.ToolProperty{ + "x": {Type: api.PropertyType{"number"}}, + "y": {Type: api.PropertyType{"integer"}}, + "enabled": {Type: api.PropertyType{"boolean"}}, + "items": {Type: api.PropertyType{"array"}}, + }), + }, + rawToolCall: ` + +3.14 + + +42 + + +true + + +["a", "b", "c"] + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "calculate", + Arguments: map[string]any{ + "x": 3.14, + "y": 42, + "enabled": true, + "items": []any{"a", "b", "c"}, + }, + }, + }, + }, + // regression test for + { + name: "ampersands in parameter values", + tools: []api.Tool{}, + rawToolCall: ` + +ls && echo "done" + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "exec", + Arguments: map[string]any{ + "command": "ls && echo \"done\"", + }, + }, + }, + }, + { + name: "angle brackets in parameter values", + tools: []api.Tool{}, + rawToolCall: ` + +ls && echo "a > b and a < b" + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "exec", + Arguments: map[string]any{ + "command": "ls && echo \"a > b and a < b\"", + }, + }, + }, + }, + { + name: "unicode in function names and parameters", + tools: []api.Tool{}, + rawToolCall: ` + +北京 + + +Hello! 你好! 🌟 مرحبا + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "获取天气", + Arguments: map[string]any{ + "城市": "北京", + "message": "Hello! 你好! 🌟 مرحبا", + }, + }, + }, + }, + } + + for i, step := range steps { + gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools) + if err != nil { + t.Errorf("step %d (%s): %v", i, step.name, err) + } + if !reflect.DeepEqual(gotToolCall, step.wantToolCall) { + t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall) + } + } +} + +func TestTrailingWhitespaceLenUnicode(t *testing.T) { + cases := []struct { + name string + input string + want int + }{ + { + name: "ascii space", + input: "Hello ", + want: 1, + }, + { + name: "non-breaking space", + input: "Hello\u00a0", + want: 2, + }, + { + name: "ideographic space", + input: "Hello\u3000", + want: 3, + }, + { + name: "multiple runes of whitespace", + input: "Hi\u00a0\u3000", + want: 5, + }, + } + + for _, tc := range cases { + got := trailingWhitespaceLen(tc.input) + if got != tc.want { + t.Errorf("%s: trailingWhitespaceLen(%q) = %d, want %d", tc.name, tc.input, got, tc.want) + } + } +} + +func TestQwenToolCallValueParsing(t *testing.T) { + cases := []struct { + desc string + raw string + paramType api.PropertyType + want any + }{ + { + desc: "default string value (no type specified)", + paramType: api.PropertyType{}, + raw: "some-string", + want: "some-string", + }, + { + desc: "trim a single leading and trailing newline", + paramType: api.PropertyType{}, + raw: "\nsome-string\n", + want: "some-string", + }, + { + desc: "trim at most one leading and trailing newline", + paramType: api.PropertyType{}, + raw: "\n\nsome-string\n\n", + want: "\nsome-string\n", + }, + { + desc: "newline really has to be the first character to be trimmed", + paramType: api.PropertyType{}, + raw: " \nsome-string\n ", + want: " \nsome-string\n ", + }, + { + desc: "numeric type", + paramType: api.PropertyType{"number"}, + raw: "123", + want: 123, + }, + // Integer parsing tests + { + desc: "integer type", + paramType: api.PropertyType{"integer"}, + raw: "42", + want: 42, + }, + { + desc: "negative integer", + paramType: api.PropertyType{"integer"}, + raw: "-100", + want: -100, + }, + { + desc: "zero integer", + paramType: api.PropertyType{"integer"}, + raw: "0", + want: 0, + }, + { + desc: "integer with leading zeros", + paramType: api.PropertyType{"integer"}, + raw: "007", + want: 7, + }, + { + desc: "large integer", + paramType: api.PropertyType{"integer"}, + raw: "2147483648", // Just beyond int32 max + want: int64(2147483648), + }, + // Float/number parsing tests + { + desc: "float type", + paramType: api.PropertyType{"number"}, + raw: "3.14", + want: 3.14, + }, + { + desc: "negative float", + paramType: api.PropertyType{"number"}, + raw: "-273.15", + want: -273.15, + }, + { + desc: "float without decimal part", + paramType: api.PropertyType{"number"}, + raw: "100.0", + want: 100, + }, + { + desc: "scientific notation positive", + paramType: api.PropertyType{"number"}, + raw: "1.23e5", + want: 123000, // Will be int since it has no decimal part + }, + { + desc: "scientific notation negative", + paramType: api.PropertyType{"number"}, + raw: "1.5e-3", + want: 0.0015, + }, + { + desc: "very small float", + paramType: api.PropertyType{"number"}, + raw: "0.00000001", + want: 0.00000001, + }, + // String parsing tests + { + desc: "explicit string type", + paramType: api.PropertyType{"string"}, + raw: "hello world", + want: "hello world", + }, + { + desc: "string with special characters", + paramType: api.PropertyType{"string"}, + raw: "/usr/local/bin/test-file_v2.0.sh", + want: "/usr/local/bin/test-file_v2.0.sh", + }, + { + desc: "string with quotes", + paramType: api.PropertyType{"string"}, + raw: `He said "hello" to me`, + want: `He said "hello" to me`, + }, + { + desc: "multiline string", + paramType: api.PropertyType{"string"}, + raw: "line one\nline two\nline three", + want: "line one\nline two\nline three", + }, + { + desc: "empty string", + paramType: api.PropertyType{"string"}, + raw: "", + want: "", + }, + { + desc: "string that looks like a number", + paramType: api.PropertyType{"string"}, + raw: "12345", + want: "12345", + }, + // Boolean parsing tests + { + desc: "boolean true", + paramType: api.PropertyType{"boolean"}, + raw: "true", + want: true, + }, + { + desc: "boolean false", + paramType: api.PropertyType{"boolean"}, + raw: "false", + want: false, + }, + { + desc: "boolean case insensitive true", + paramType: api.PropertyType{"boolean"}, + raw: "True", + want: true, + }, + { + desc: "boolean case insensitive false", + paramType: api.PropertyType{"boolean"}, + raw: "FALSE", + want: false, + }, + // Null parsing tests + { + desc: "null value lowercase", + paramType: api.PropertyType{"string"}, + raw: "null", + want: nil, + }, + { + desc: "null value case insensitive", + paramType: api.PropertyType{"integer"}, + raw: "NULL", + want: nil, + }, + // Array parsing tests + { + desc: "array of strings", + paramType: api.PropertyType{"array"}, + raw: `["foo", "bar", "baz"]`, + want: []any{"foo", "bar", "baz"}, + }, + { + desc: "array of numbers", + paramType: api.PropertyType{"array"}, + raw: `[1, 2.5, 3]`, + want: []any{float64(1), 2.5, float64(3)}, + }, + { + desc: "array of mixed types", + paramType: api.PropertyType{"array"}, + raw: `["string", 123, true, null]`, + want: []any{"string", float64(123), true, nil}, + }, + { + desc: "empty array", + paramType: api.PropertyType{"array"}, + raw: `[]`, + want: []any{}, + }, + // Object parsing tests + { + desc: "simple object", + paramType: api.PropertyType{"object"}, + raw: `{"key": "value", "number": 42}`, + want: map[string]any{"key": "value", "number": float64(42)}, + }, + { + desc: "nested object", + paramType: api.PropertyType{"object"}, + raw: `{"outer": {"inner": "value"}}`, + want: map[string]any{"outer": map[string]any{"inner": "value"}}, + }, + { + desc: "empty object", + paramType: api.PropertyType{"object"}, + raw: `{}`, + want: map[string]any{}, + }, + // Error cases and fallback behavior + { + desc: "invalid integer falls back to string", + paramType: api.PropertyType{"integer"}, + raw: "not-a-number", + want: "not-a-number", + }, + { + desc: "invalid float falls back to string", + paramType: api.PropertyType{"number"}, + raw: "3.14.159", + want: "3.14.159", + }, + { + desc: "invalid boolean falls back to false", + paramType: api.PropertyType{"boolean"}, + raw: "yes", + want: false, + }, + { + desc: "invalid JSON array falls back to string", + paramType: api.PropertyType{"array"}, + raw: "[1, 2, unclosed", + want: "[1, 2, unclosed", + }, + { + desc: "invalid JSON object falls back to string", + paramType: api.PropertyType{"object"}, + raw: `{"key": unclosed`, + want: `{"key": unclosed`, + }, + // Edge cases + { + desc: "integer overflow should use int64", + paramType: api.PropertyType{"integer"}, + raw: "2147483648", // Beyond int32 max + want: int64(2147483648), + }, + { + desc: "float with many decimal places", + paramType: api.PropertyType{"number"}, + raw: "3.141592653589793", + want: 3.141592653589793, + }, + { + desc: "string with JSON-like content", + paramType: api.PropertyType{"string"}, + raw: `{"this": "is", "just": "a string"}`, + want: `{"this": "is", "just": "a string"}`, + }, + { + desc: "whitespace-only string", + paramType: api.PropertyType{"string"}, + raw: " ", + want: " ", + }, + // Unknown parameter (no type specified in tools) + { + desc: "parameter not in tool definition defaults to string", + paramType: api.PropertyType{}, + raw: "some value", + want: "some value", + }, + // Union type tests + { + desc: "string or number union - valid number", + paramType: api.PropertyType{"string", "number"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "string or number union - non-numeric string", + paramType: api.PropertyType{"string", "number"}, + raw: "hello", + want: "hello", + }, + { + desc: "number or string union - valid number (order shouldn't matter)", + paramType: api.PropertyType{"number", "string"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "integer or null union - valid integer", + paramType: api.PropertyType{"integer", "null"}, + raw: "123", + want: 123, + }, + { + desc: "integer or null union - null value", + paramType: api.PropertyType{"integer", "null"}, + raw: "null", + want: nil, + }, + { + desc: "null or integer union - null value (order shouldn't matter)", + paramType: api.PropertyType{"null", "integer"}, + raw: "null", + want: nil, + }, + { + desc: "boolean or string union - valid boolean", + paramType: api.PropertyType{"boolean", "string"}, + raw: "true", + want: true, + }, + { + desc: "boolean or string union - non-boolean becomes string", + paramType: api.PropertyType{"boolean", "string"}, + raw: "yes", + want: "yes", + }, + { + desc: "string or boolean union - valid boolean (precedence test)", + paramType: api.PropertyType{"string", "boolean"}, + raw: "false", + want: false, // Should be boolean, not string "false" + }, + { + desc: "integer or number union - integer value", + paramType: api.PropertyType{"integer", "number"}, + raw: "42", + want: 42, + }, + { + desc: "integer or number union - float value", + paramType: api.PropertyType{"integer", "number"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "number or integer union - integer value (precedence test)", + paramType: api.PropertyType{"number", "integer"}, + raw: "42", + want: 42, // Should try integer first due to precedence + }, + { + desc: "array or object union - valid array", + paramType: api.PropertyType{"array", "object"}, + raw: `[1, 2, 3]`, + want: []any{float64(1), float64(2), float64(3)}, + }, + { + desc: "array or object union - valid object", + paramType: api.PropertyType{"array", "object"}, + raw: `{"key": "value"}`, + want: map[string]any{"key": "value"}, + }, + { + desc: "object or array union - valid array (precedence test)", + paramType: api.PropertyType{"object", "array"}, + raw: `[1, 2, 3]`, + want: []any{float64(1), float64(2), float64(3)}, + }, + { + desc: "complex multi-type union - null", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "null", + want: nil, + }, + { + desc: "complex multi-type union - boolean", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "true", + want: true, + }, + { + desc: "complex multi-type union - number", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "3.14", + want: 3.14, + }, + { + desc: "complex multi-type union - string", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "hello", + want: "hello", + }, + { + desc: "integer string union - integer string becomes integer", + paramType: api.PropertyType{"integer", "string"}, + raw: "123", + want: 123, + }, + { + desc: "string integer union - integer string becomes integer (precedence)", + paramType: api.PropertyType{"string", "integer"}, + raw: "123", + want: 123, // Integer has higher precedence than string + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got := parseValue(tc.raw, tc.paramType) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want) + } + }) + } +} + +func TestQwenXMLTransform(t *testing.T) { + cases := []struct { + desc string + raw string + want string + }{ + { + desc: "simple example", + raw: ` + +San Francisco + + +celsius + +`, + want: ` + +San Francisco + + +celsius + +`, + }, + // even though quotes aren't expected in these tags, we have these tests to + // make sure they're escaped so they don't blow up the xml parser in case + // they happen + { + desc: "names with quotes", + raw: ` + +San Francisco + + +celsius + +`, + want: ` + +San Francisco + + +celsius + +`, + }, + { + desc: "ampersands in parameter values", + raw: ` + + San Francisco & San Jose + + `, + want: ` + + San Francisco & San Jose + + `, + }, + } + + for _, tc := range cases { + got := transformToXML(tc.raw) + if got != tc.want { + t.Errorf("got %q, want %q", got, tc.want) + } + } +} + +func TestTrailingWhitespaceLen(t *testing.T) { + cases := []struct { + desc string + s string + want int + }{ + {desc: "no whitespace", s: "abc", want: 0}, + {desc: "trailing whitespace", s: "abc ", want: 1}, + {desc: "trailing whitespace with newlines", s: "abc \n", want: 2}, + {desc: "only whitespace", s: " \n ", want: 4}, + {desc: "leading whitespace doesn't count", s: " \n abc", want: 0}, + {desc: "unicode with trailing space", s: "测试🎯 ", want: 1}, + {desc: "unicode with trailing tab and newline", s: "مرحبا\t\n", want: 2}, + } + + for _, tc := range cases { + got := trailingWhitespaceLen(tc.s) + if got != tc.want { + t.Errorf("got %d, want %d", got, tc.want) + } + } +} + +func TestOverlapFunction(t *testing.T) { + cases := []struct { + desc string + s string + delim string + want int + }{ + {desc: "no overlap", s: "hello", delim: "", want: 5}, + {desc: "partial overlap", s: "hello", want: 3}, + {desc: "unicode with partial overlap", s: "测试🎯", want: 3}, + {desc: "unicode string with no overlap", s: "مرحبا", delim: "", want: 0}, + {desc: "unicode at boundary", s: "世界<", delim: "", want: 1}, + {desc: "unicode delimiter single rune", s: "hello🔧", delim: "🔧工具", want: len("🔧")}, + {desc: "unicode delimiter multiple runes", s: "hello🔧工", delim: "🔧工具", want: len("🔧工")}, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got := overlap(tc.s, tc.delim) + if got != tc.want { + t.Errorf("overlap(%q, %q) = %d, want %d", tc.s, tc.delim, got, tc.want) + } + }) + } +} diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go new file mode 100644 index 000000000..32611791b --- /dev/null +++ b/model/renderers/qwen3coder.go @@ -0,0 +1,236 @@ +package renderers + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/ollama/ollama/api" +) + +var ( + imStartTag = "<|im_start|>" + imEndTag = "<|im_end|>" +) + +// renderAdditionalKeys renders all JSON fields except the ones in handledKeys +// This follows the same approach from the reference implementation, which gives +// a particular key ordering +func renderAdditionalKeys(obj any, handledKeys map[string]bool) string { + data, err := json.Marshal(obj) + if err != nil { + return "" + } + + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return "" + } + + var sb strings.Builder + for key, value := range m { + if handledKeys[key] { + continue + } + + // Check if value is a map or array (needs JSON serialization) + switch v := value.(type) { + case map[string]any, []any: + jsonBytes, _ := json.Marshal(v) + // TODO(drifkin): it would be nice to format the JSON here similarly to + // python's default json.dumps behavior (spaces after commas and colons). + // This would let us be byte-for-byte compatible with the reference + // implementation for most common inputs + jsonStr := string(jsonBytes) + sb.WriteString("\n<" + key + ">" + jsonStr + "") + case nil: + continue + default: + // Simple types, convert to string + sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "") + } + } + + return sb.String() +} + +func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) { + var sb strings.Builder + + // filter out system messages and choose the first (if any) to win + var systemMessage string + var filteredMessages []api.Message + for _, message := range messages { + if message.Role != "system" { + filteredMessages = append(filteredMessages, message) + continue + } + + if systemMessage == "" { + systemMessage = message.Content + } + } + + if systemMessage != "" || len(tools) > 0 { + sb.WriteString(imStartTag + "system\n") + + // if we have tools but no system message, match the reference implementation by providing a default system message + if systemMessage == "" { + systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." + } + + sb.WriteString(systemMessage) + + if len(tools) > 0 { + sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n") + sb.WriteString("") + for _, tool := range tools { + sb.WriteString("\n") + sb.WriteString("\n") + sb.WriteString("" + tool.Function.Name + "") + if tool.Function.Description != "" { + sb.WriteString("\n" + tool.Function.Description + "") + } + sb.WriteString("\n") + + for name, prop := range tool.Function.Parameters.Properties { + sb.WriteString("\n") + sb.WriteString("\n" + name + "") + + if len(prop.Type) > 0 { + sb.WriteString("\n" + formatToolDefinitionType(prop.Type) + "") + } + + if prop.Description != "" { + sb.WriteString("\n" + prop.Description + "") + } + + // Render any additional keys not already handled + handledKeys := map[string]bool{ + "type": true, + "description": true, + } + sb.WriteString(renderAdditionalKeys(prop, handledKeys)) + + sb.WriteString("\n") + } + + // Render extra keys for parameters (everything except 'type' and 'properties') + paramHandledKeys := map[string]bool{ + "type": true, + "properties": true, + } + sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys)) + + sb.WriteString("\n") + sb.WriteString("\n") + } + sb.WriteString("\n") + sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n") + } + + sb.WriteString(imEndTag + "\n") + } + + for i, message := range filteredMessages { + lastMessage := i == len(filteredMessages)-1 + prefill := lastMessage && message.Role == "assistant" + switch message.Role { + case "assistant": + if len(message.ToolCalls) > 0 { + sb.WriteString(imStartTag + "assistant\n") + if message.Content != "" { + sb.WriteString(message.Content + "\n") + } + for _, toolCall := range message.ToolCalls { + sb.WriteString("\n\n") + for name, value := range toolCall.Function.Arguments { + valueStr := formatToolCallArgument(value) + sb.WriteString("\n\n" + valueStr + "\n") + } + sb.WriteString("\n\n") + } + sb.WriteString("<|im_end|>\n") + } else { + sb.WriteString(imStartTag + "assistant\n") + sb.WriteString(message.Content) + if !prefill { + sb.WriteString(imEndTag + "\n") + } + } + case "tool": + // consecutive tool responses should share a single `user`, but + // have their own tags + + // only start a new user block if this is the first tool response + if i == 0 || filteredMessages[i-1].Role != "tool" { + sb.WriteString(imStartTag + "user\n") + } + + sb.WriteString("\n") + sb.WriteString(message.Content) + sb.WriteString("\n\n") + + // close the user block only if this is the last tool response + if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" { + sb.WriteString(imEndTag + "\n") + } + default: + sb.WriteString(imStartTag + message.Role + "\n") + sb.WriteString(message.Content) + sb.WriteString(imEndTag + "\n") + } + + if lastMessage && !prefill { + sb.WriteString(imStartTag + "assistant\n") + } + } + + return sb.String(), nil +} + +func formatToolCallArgument(value any) string { + if value == nil { + return "null" + } + + switch v := value.(type) { + case string: + return v + case []byte: + return string(v) + } + + if reflect.TypeOf(value) != nil { + kind := reflect.TypeOf(value).Kind() + if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array { + if marshalled, err := json.Marshal(value); err == nil { + return string(marshalled) + } + } + } + + return fmt.Sprintf("%v", value) +} + +func formatToolDefinitionType(tp api.PropertyType) string { + if len(tp) == 0 { + return "[]" + } + + if len(tp) == 1 { + return tp[0] + } + + // TODO(drifkin): it would be nice to format the JSON here similarly to + // python's default json.dumps behavior (spaces after commas and colons). + // This would let us be byte-for-byte compatible with the reference + // implementation for most common inputs + jsonBytes, err := json.Marshal(tp) + if err != nil { + return "[]" + } + + return string(jsonBytes) +} diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go new file mode 100644 index 000000000..6a9e5eccd --- /dev/null +++ b/model/renderers/qwen3coder_test.go @@ -0,0 +1,370 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +func TestQwen3CoderRenderer(t *testing.T) { + tests := []struct { + name string + msgs []api.Message + tools []api.Tool + expected string + }{ + { + name: "basic", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello, how are you?"}, + }, + expected: `<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +Hello, how are you?<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "with tools and response", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant with access to tools."}, + {Role: "user", Content: "What is the weather like in San Francisco?"}, + { + Role: "assistant", + Content: "I'll check the weather in San Francisco for you.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{ + "unit": "fahrenheit", + }, + }, + }, + }, + }, + {Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"}, + {Role: "user", Content: "That sounds nice! What about New York?"}, + }, + tools: []api.Tool{ + {Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: api.ToolFunctionParameters{ + Required: []string{"unit"}, + Properties: map[string]api.ToolProperty{ + "unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"}, + // TODO(drifkin): add multiple params back once we have predictable + // order via some sort of ordered map type (see + // ) + /* + "location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"}, + */ + }, + }, + }}, + }, + expected: `<|im_start|>system +You are a helpful assistant with access to tools. + +# Tools + +You have access to the following functions: + + + +get_weather +Get the current weather in a given location + + +unit +string +The unit of temperature +["celsius","fahrenheit"] + +["unit"] + + + + +If you choose to call a function ONLY reply in the following format with NO suffix: + + + + +value_1 + + +This is the value for the second parameter +that can span +multiple lines + + + + + +Reminder: +- Function calls MUST follow the specified format: an inner block must be nested within XML tags +- Required parameters MUST be specified +- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after +- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls +<|im_end|> +<|im_start|>user +What is the weather like in San Francisco?<|im_end|> +<|im_start|>assistant +I'll check the weather in San Francisco for you. + + + + +fahrenheit + + +<|im_end|> +<|im_start|>user + +{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12} + +<|im_end|> +<|im_start|>user +That sounds nice! What about New York?<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "parallel tool calls", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant with access to tools."}, + {Role: "user", Content: "call double(1) and triple(2)"}, + {Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}}, + {Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}}, + }}, + {Role: "tool", Content: "{\"number\": 2}", ToolName: "double"}, + {Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"}, + }, + tools: []api.Tool{ + {Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{ + "number": {Type: api.PropertyType{"string"}, Description: "The number to double"}, + }}}}, + {Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{ + "number": {Type: api.PropertyType{"string"}, Description: "The number to triple"}, + }}}}, + }, + expected: `<|im_start|>system +You are a helpful assistant with access to tools. + +# Tools + +You have access to the following functions: + + + +double +Double a number + + +number +string +The number to double + + + + +triple +Triple a number + + +number +string +The number to triple + + + + + +If you choose to call a function ONLY reply in the following format with NO suffix: + + + + +value_1 + + +This is the value for the second parameter +that can span +multiple lines + + + + + +Reminder: +- Function calls MUST follow the specified format: an inner block must be nested within XML tags +- Required parameters MUST be specified +- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after +- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls +<|im_end|> +<|im_start|>user +call double(1) and triple(2)<|im_end|> +<|im_start|>assistant +I'll call double(1) and triple(2) for you. + + + + +1 + + + + + + +2 + + +<|im_end|> +<|im_start|>user + +{"number": 2} + + +{"number": 6} + +<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "prefill", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Tell me something interesting."}, + {Role: "assistant", Content: "I'll tell you something interesting about cats"}, + }, + expected: `<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +Tell me something interesting.<|im_end|> +<|im_start|>assistant +I'll tell you something interesting about cats`, + }, + { + name: "complex tool call arguments should remain json encoded", + msgs: []api.Message{ + {Role: "user", Content: "call tool"}, + {Role: "assistant", ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{ + Name: "echo", + Arguments: map[string]any{ + "payload": map[string]any{"foo": "bar"}, + }, + }}, + }}, + {Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"}, + }, + expected: `<|im_start|>user +call tool<|im_end|> +<|im_start|>assistant + + + + +{"foo":"bar"} + + +<|im_end|> +<|im_start|>user + +{"payload": {"foo": "bar"}} + +<|im_end|> +<|im_start|>assistant +`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestFormatToolCallArgument(t *testing.T) { + tests := []struct { + name string + arg any + expected string + }{ + { + name: "string", + arg: "foo", + // notice no quotes around the string + expected: "foo", + }, + { + name: "map", + arg: map[string]any{"foo": "bar"}, + expected: "{\"foo\":\"bar\"}", + }, + { + name: "number", + arg: 1, + expected: "1", + }, + { + name: "boolean", + arg: true, + expected: "true", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatToolCallArgument(tt.arg) + if got != tt.expected { + t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected) + } + }) + } +} + +func TestQwen3ToolDefinitionTypes(t *testing.T) { + tests := []struct { + name string + propertyType api.PropertyType + expected string + }{ + { + name: "simple", + propertyType: api.PropertyType{"string"}, + expected: "string", + }, + { + name: "multiple", + propertyType: api.PropertyType{"string", "number"}, + expected: "[\"string\",\"number\"]", + }, + { + name: "empty", + propertyType: api.PropertyType{}, + expected: "[]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatToolDefinitionType(tt.propertyType) + if got != tt.expected { + t.Errorf("formatToolDefinitionType() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go new file mode 100644 index 000000000..2dfb51e49 --- /dev/null +++ b/model/renderers/renderer.go @@ -0,0 +1,26 @@ +package renderers + +import ( + "fmt" + + "github.com/ollama/ollama/api" +) + +type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error) + +func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { + renderer := rendererForName(name) + if renderer == nil { + return "", fmt.Errorf("unknown renderer %q", name) + } + return renderer(msgs, tools, think) +} + +func rendererForName(name string) rendererFunc { + switch name { + case "qwen3-coder": + return Qwen3CoderRenderer + default: + return nil + } +} diff --git a/model/sentencepiece.go b/model/sentencepiece.go index 7d725f04f..db07beee9 100644 --- a/model/sentencepiece.go +++ b/model/sentencepiece.go @@ -2,7 +2,6 @@ package model import ( "container/heap" - "context" "fmt" "log/slog" "strconv" @@ -13,19 +12,19 @@ import ( const spmWhitespaceSep = "▁" -type SentencePieceModel struct { +type SentencePiece struct { maxTokenLen int vocab *Vocabulary } -var _ TextProcessor = (*SentencePieceModel)(nil) +var _ TextProcessor = (*SentencePiece)(nil) -func (spm SentencePieceModel) Vocabulary() *Vocabulary { +func (spm SentencePiece) Vocabulary() *Vocabulary { return spm.vocab } -func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { - slog.Log(context.TODO(), logutil.LevelTrace, "Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) +func NewSentencePiece(vocab *Vocabulary) SentencePiece { + logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) counter := map[int]int{} var maxTokenLen int @@ -39,21 +38,21 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { } } - slog.Log(context.TODO(), logutil.LevelTrace, "Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL], + logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL], "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], "max token len", maxTokenLen) - return SentencePieceModel{ + return SentencePiece{ maxTokenLen: maxTokenLen, vocab: vocab, } } -func (spm SentencePieceModel) Is(id int32, special Special) bool { +func (spm SentencePiece) Is(id int32, special Special) bool { return spm.vocab.Is(id, special) } -func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) { +func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) { fragments := []fragment{{value: s}} for _, special := range spm.vocab.SpecialVocabulary() { id := spm.vocab.Encode(special) @@ -182,12 +181,11 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) } } - slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids) - if addSpecial && len(ids) > 0 { ids = spm.vocab.addSpecials(ids) } + logutil.Trace("encoded", "string", s, "ids", ids) return ids, nil } @@ -220,7 +218,7 @@ func (q *queue) Pop() interface{} { return item } -func (spm SentencePieceModel) Decode(ids []int32) (string, error) { +func (spm SentencePiece) Decode(ids []int32) (string, error) { var sb strings.Builder for _, id := range ids { data := spm.vocab.Decode(id) @@ -246,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) { } } - slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String()) + logutil.Trace("decoded", "ids", ids, "string", sb.String()) return sb.String(), nil } diff --git a/model/sentencepiece_test.go b/model/sentencepiece_test.go index 50ac26787..8f4570c17 100644 --- a/model/sentencepiece_test.go +++ b/model/sentencepiece_test.go @@ -12,7 +12,7 @@ import ( "github.com/ollama/ollama/convert/sentencepiece" ) -func loadSentencePieceVocab(t *testing.T) SentencePieceModel { +func loadSentencePieceVocab(t *testing.T) SentencePiece { t.Helper() bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model")) @@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel { } } - return NewSentencePieceModel(&v) + return NewSentencePiece(&v) } func TestSentencePieceEncode(t *testing.T) { @@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) { }) } -func TestSentencePieceModelDecodeByteTokens(t *testing.T) { +func TestSentencePieceDecodeByteTokens(t *testing.T) { vocab := &Vocabulary{ Values: []string{ "normal", @@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) { Scores: []float32{0, 0, 0, 0, 0}, } - spm := NewSentencePieceModel(vocab) + spm := NewSentencePiece(vocab) tests := []struct { name string diff --git a/model/vocabulary.go b/model/vocabulary.go index a86de58df..9b7fc789e 100644 --- a/model/vocabulary.go +++ b/model/vocabulary.go @@ -49,7 +49,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 { slog.Warn("adding bos token to prompt which already has it", "id", v.BOS) } - slog.Debug("adding bos token to prompt", "id", v.BOS) + slog.Debug("adding bos token to prompt", "id", v.BOS[0]) ids = append([]int32{v.BOS[0]}, ids...) } @@ -58,7 +58,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 { slog.Warn("adding eos token to prompt which already has it", "id", v.EOS) } - slog.Debug("adding eos token to prompt", "id", v.EOS) + slog.Debug("adding eos token to prompt", "id", v.EOS[0]) ids = append(ids, v.EOS[0]) } diff --git a/model/wordpiece.go b/model/wordpiece.go new file mode 100644 index 000000000..e8d5e848a --- /dev/null +++ b/model/wordpiece.go @@ -0,0 +1,167 @@ +package model + +import ( + "fmt" + "iter" + "strings" + "unicode" + + "github.com/ollama/ollama/logutil" +) + +type WordPiece struct { + vocab *Vocabulary +} + +// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries. +// this differs from original word piece which uses "##" to indicate subwords. +const ggmlPrefix = "▁" + +var wordPieceReplacer = strings.NewReplacer( + " .", ".", + " ?", "?", + " !", "!", + " ,", ",", + " ' ", "'", + " n't", "n't", + " 'm", "'m", + " do not", " don't", + " 's", "'s", + " 've", "'ve", + " 're", "'re", +) + +// Decode implements TextProcessor. +func (wpm WordPiece) Decode(ids []int32) (string, error) { + var sb strings.Builder + for i, id := range ids { + if id < 0 || int(id) >= len(wpm.vocab.Values) { + return "", fmt.Errorf("invalid token id: %d", id) + } + + var separator string + piece := wpm.vocab.Values[id] + if i > 0 && + (strings.HasPrefix(piece, ggmlPrefix) || + (strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) { + separator = " " + } + + sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix))) + } + + return sb.String(), nil +} + +// words splits a string into words, treating CJK characters as separate words. +// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models. +func (wpm WordPiece) words(s string) iter.Seq[string] { + return func(yield func(string) bool) { + runes := make([]rune, 0, len(s)*3) + for _, r := range s { + switch { + case r >= 0x4E00 && r <= 0x9FFF, + r >= 0x3400 && r <= 0x4DBF, + r >= 0x20000 && r <= 0x2A6DF, + r >= 0x2A700 && r <= 0x2B73F, + r >= 0x2B740 && r <= 0x2B81F, + r >= 0x2B820 && r <= 0x2CEAF, + r >= 0xF900 && r <= 0xFAFF, + r >= 0x2F800 && r <= 0x2FA1F: + runes = append(runes, ' ', r, ' ') + default: + runes = append(runes, r) + } + } + + for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) { + // split on but keep punctuation + var start int + for start < len(w) { + end := strings.IndexFunc(w[start:], unicode.IsPunct) + if end < 0 { + end = len(w) - start + } else if end == 0 { + end = 1 + } + + if !yield(w[start : start+end]) { + return + } + + start += end + } + } + } +} + +// Encode implements TextProcessor. +func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) { + var ids []int32 + + // TODO: use [UNK] from config + unk := wpm.vocab.Encode("[UNK]") + for word := range wpm.words(s) { + var start int + var pieces []int32 + for start < len(word) { + end := len(word) + + var piece int32 + for start < end { + subword := word[start:end] + if start == 0 { + subword = ggmlPrefix + subword + } + + // TODO: some models might not want [ToLower] + piece = wpm.vocab.Encode(strings.ToLower(subword)) + if piece >= 0 { + break + } + + end-- + } + + if piece < 0 { + // Unknown token + pieces = pieces[:0] + break + } + + pieces = append(pieces, piece) + start = end + } + + if len(pieces) > 0 { + ids = append(ids, pieces...) + } else { + ids = append(ids, unk) + } + } + + if addSpecial && len(ids) > 0 { + ids = wpm.vocab.addSpecials(ids) + } + + logutil.Trace("encoded", "string", s, "ids", ids) + return ids, nil +} + +// Is implements TextProcessor. +func (wpm WordPiece) Is(id int32, special Special) bool { + return wpm.vocab.Is(id, special) +} + +// Vocabulary implements TextProcessor. +func (wpm WordPiece) Vocabulary() *Vocabulary { + return wpm.vocab +} + +var _ TextProcessor = (*WordPiece)(nil) + +func NewWordPiece(vocab *Vocabulary) WordPiece { + return WordPiece{ + vocab: vocab, + } +} diff --git a/model/wordpiece_test.go b/model/wordpiece_test.go new file mode 100644 index 000000000..258fbffcb --- /dev/null +++ b/model/wordpiece_test.go @@ -0,0 +1,51 @@ +package model + +import ( + "slices" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestWordPiece(t *testing.T) { + wpm := NewWordPiece( + &Vocabulary{ + Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"}, + AddBOS: true, + AddEOS: true, + BOS: []int32{1}, + EOS: []int32{2}, + }) + + ids, err := wpm.Encode("Hello world!", true) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" { + t.Errorf("unexpected ids (-want +got):\n%s", diff) + } + + words, err := wpm.Decode(ids) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } +} + +func TestWordPieceWords(t *testing.T) { + var wpm WordPiece + + basic := slices.Collect(wpm.words("Hey friend! How are you?!?")) + if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } + + chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika")) + if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } +} diff --git a/openai/openai.go b/openai/openai.go index 13b9c425f..7ef5ac6de 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -76,8 +76,9 @@ type JsonSchema struct { } type EmbedRequest struct { - Input any `json:"input"` - Model string `json:"model"` + Input any `json:"input"` + Model string `json:"model"` + Dimensions int `json:"dimensions,omitempty"` } type StreamOptions struct { @@ -104,16 +105,18 @@ type ChatCompletionRequest struct { Tools []api.Tool `json:"tools"` Reasoning *Reasoning `json:"reasoning,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"` + DebugRenderOnly bool `json:"_debug_render_only"` } type ChatCompletion struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Choices []Choice `json:"choices"` - Usage Usage `json:"usage,omitempty"` + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage,omitempty"` + DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"` } type ChatCompletionChunk struct { @@ -140,6 +143,7 @@ type CompletionRequest struct { Temperature *float32 `json:"temperature"` TopP float32 `json:"top_p"` Suffix string `json:"suffix"` + DebugRenderOnly bool `json:"_debug_render_only"` } type Completion struct { @@ -272,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { } return nil }(r.DoneReason), - }}, - Usage: toUsage(r), + }}, Usage: toUsage(r), + DebugInfo: r.DebugInfo, } } @@ -557,25 +561,24 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { var think *api.ThinkValue if r.Reasoning != nil { - options["reasoning"] = *r.Reasoning.Effort think = &api.ThinkValue{ Value: *r.Reasoning.Effort, } } else if r.ReasoningEffort != nil { - options["reasoning"] = *r.ReasoningEffort think = &api.ThinkValue{ Value: *r.ReasoningEffort, } } return &api.ChatRequest{ - Model: r.Model, - Messages: messages, - Format: format, - Options: options, - Stream: &r.Stream, - Tools: r.Tools, - Think: think, + Model: r.Model, + Messages: messages, + Format: format, + Options: options, + Stream: &r.Stream, + Tools: r.Tools, + Think: think, + DebugRenderOnly: r.DebugRenderOnly, }, nil } @@ -649,11 +652,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { } return api.GenerateRequest{ - Model: r.Model, - Prompt: r.Prompt, - Options: options, - Stream: &r.Stream, - Suffix: r.Suffix, + Model: r.Model, + Prompt: r.Prompt, + Options: options, + Stream: &r.Stream, + Suffix: r.Suffix, + DebugRenderOnly: r.DebugRenderOnly, }, nil } @@ -1007,7 +1011,7 @@ func EmbeddingsMiddleware() gin.HandlerFunc { } var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil { + if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) return } diff --git a/parser/parser.go b/parser/parser.go index d40a79c29..c2e8f981f 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) req.System = c.Args case "license": licenses = append(licenses, c.Args) + case "renderer": + req.Renderer = c.Args + case "parser": + req.Parser = c.Args case "message": role, msg, _ := strings.Cut(c.Args, ": ") messages = append(messages, api.Message{Role: role, Content: msg}) @@ -246,7 +250,7 @@ func filesForModel(path string) ([]string, error) { for _, match := range matches { if ct, err := detectContentType(match); err != nil { return nil, err - } else if ct != contentType { + } else if len(contentType) > 0 && ct != contentType { return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match) } } @@ -255,7 +259,8 @@ func filesForModel(path string) ([]string, error) { } var files []string - if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 { + // some safetensors files do not properly match "application/octet-stream", so skip checking their contentType + if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 { // safetensors files might be unresolved git lfs references; skip if they are // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors files = append(files, st...) @@ -319,7 +324,7 @@ func (c Command) String() string { switch c.Name { case "model": fmt.Fprintf(&sb, "FROM %s", c.Args) - case "license", "template", "system", "adapter": + case "license", "template", "system", "adapter", "renderer", "parser": fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) case "message": role, message, _ := strings.Cut(c.Args, ": ") @@ -345,7 +350,7 @@ const ( var ( errMissingFrom = errors.New("no FROM line") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") - errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") + errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"") ) type ParserError struct { @@ -605,7 +610,7 @@ func isValidMessageRole(role string) bool { func isValidCommand(cmd string) bool { switch strings.ToLower(cmd) { - case "from", "license", "template", "system", "adapter", "parameter", "message": + case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message": return true default: return false diff --git a/parser/parser_test.go b/parser/parser_test.go index 7d5a808ba..1524e890a 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -198,6 +198,34 @@ BADCOMMAND param1 value1 } } +func TestParseFileRenderer(t *testing.T) { + input := ` +FROM foo +RENDERER renderer1 +` + + reader := strings.NewReader(input) + + modelfile, err := ParseFile(reader) + require.NoError(t, err) + + assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands) +} + +func TestParseFileParser(t *testing.T) { + input := ` +FROM foo +PARSER parser1 +` + + reader := strings.NewReader(input) + + modelfile, err := ParseFile(reader) + require.NoError(t, err) + + assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands) +} + func TestParseFileMessages(t *testing.T) { cases := []struct { input string diff --git a/runner/llamarunner/cache.go b/runner/llamarunner/cache.go index 2e273e69c..9ed1c2924 100644 --- a/runner/llamarunner/cache.go +++ b/runner/llamarunner/cache.go @@ -46,7 +46,7 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b } // Locking: Operations on InputCacheSlot (including finding one -// through LoadCacheSlot) require a lock to be be held that serializes +// through LoadCacheSlot) require a lock to be held that serializes // these operations with each other and llama.Decode type InputCacheSlot struct { @@ -204,13 +204,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int { targetFree = max(targetFree, 1) currentFree := c.numCtx - inputLen - discard := targetFree - currentFree - if discard < 0 { - discard = 0 - } - - return discard + return max(targetFree-currentFree, 0) } type ErrReprocessInputs struct { diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 791492bbb..ae26b52bf 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -812,7 +812,7 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) { numGPU := 0 for i := range gpuIDs { for _, layers := range req.GPULayers { - if gpuIDs[i] == layers.ID { + if gpuIDs[i] == layers.DeviceID { tensorSplit[i] = float32(len(layers.Layers)) numGPU += len(layers.Layers) } diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 8c8a29d85..a3ffc3bd2 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -34,8 +34,8 @@ type InputCache struct { func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) { numCtx := kvSize / int32(numSlots) - if numCtx < 1 { - return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots) + if int(numCtx) < batchSize { + return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots) } slots := make([]InputCacheSlot, numSlots) @@ -70,15 +70,13 @@ func kvCacheTypeFromStr(s string) ml.DType { } func (c *InputCache) Close() { - if c == nil { - return + if c != nil && c.cache != nil { + c.cache.Close() } - - c.cache.Close() } // Locking: Operations on InputCacheSlot (including finding one -// through LoadCacheSlot) require a lock to be be held that serializes +// through LoadCacheSlot) require a lock to be held that serializes // these operations with each other and processBatch type InputCacheSlot struct { @@ -86,7 +84,7 @@ type InputCacheSlot struct { Id int // Inputs that are stored in the KV cache - Inputs []input.Input + Inputs []*input.Input // is this cache actively being processed as part of a sequence? InUse bool @@ -95,7 +93,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) { +func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) { var slot *InputCacheSlot var numPast int32 var err error @@ -113,6 +111,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp return nil, nil, err } + if !cachePrompt { + numPast = 0 + } + slot.InUse = true slot.lastUsed = time.Now() @@ -146,7 +148,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp return slot, prompt, nil } -func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { +func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) { longest := int32(-1) var longestSlot *InputCacheSlot @@ -169,7 +171,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot return longestSlot, longest, nil } -func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { +func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) { oldest := time.Now() var oldestSlot *InputCacheSlot @@ -205,7 +207,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i if longest > 0 && longestSlot != oldestSlot { slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", len(longestSlot.Inputs)) - oldestSlot.Inputs = make([]input.Input, longest) + oldestSlot.Inputs = make([]*input.Input, longest) copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) if c.cache != nil { c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) @@ -215,7 +217,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i return oldestSlot, longest, nil } -func countCommonPrefix(a []input.Input, b []input.Input) int32 { +func countCommonPrefix(a []*input.Input, b []*input.Input) int32 { var count int32 for i := range a { @@ -240,17 +242,12 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { targetFree = max(targetFree, 1) currentFree := c.numCtx - inputLen - discard := targetFree - currentFree - if discard < 0 { - discard = 0 - } - - return discard + return max(targetFree-currentFree, 0) } type ErrReprocessInputs struct { - Inputs []input.Input + Inputs []*input.Input } func (e *ErrReprocessInputs) Error() string { @@ -283,13 +280,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error { "id", slot.Id, "error", err) // Create new input slice with preserved tokens (numKeep + remaining tokens after discard) - newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard)) + newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard)) copy(newInputs[:numKeep], slot.Inputs[:numKeep]) copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) // Reset the cache _ = c.cache.Remove(slot.Id, 0, math.MaxInt32) - slot.Inputs = []input.Input{} + slot.Inputs = []*input.Input{} // Return error with inputs that need to be reprocessed return &ErrReprocessInputs{Inputs: newInputs} diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 6897b5e46..c0693e834 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -13,50 +13,50 @@ import ( func TestCountCommon(t *testing.T) { tests := []struct { name string - t1 []input.Input - t2 []input.Input + t1 []*input.Input + t2 []*input.Input expected int32 }{ { name: "Equal", - t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, - t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 3, }, { name: "Prefix", - t1: []input.Input{{Token: 1}}, - t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []*input.Input{{Token: 1}}, + t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 1, }, { name: "Image Prefix", - t1: []input.Input{{MultimodalHash: 1}}, - t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}}, + t1: []*input.Input{{MultimodalHash: 1}}, + t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}}, expected: 1, }, { name: "Mixed", - t1: []input.Input{{Token: 1}, {MultimodalHash: 1}}, - t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}}, + t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}}, + t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}}, expected: 2, }, { name: "Mixed, Same Length", - t1: []input.Input{{Token: 1}, {MultimodalHash: 1}}, - t2: []input.Input{{Token: 1}, {MultimodalHash: 2}}, + t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}}, + t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}}, expected: 1, }, { name: "Empty", - t1: []input.Input{}, - t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []*input.Input{}, + t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 0, }, { name: "Both Empty", - t1: []input.Input{}, - t2: []input.Input{}, + t1: []*input.Input{}, + t2: []*input.Input{}, expected: 0, }, } @@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) { tests := []struct { name string cache InputCache - prompt []input.Input + prompt []*input.Input longest expected best expected }{ @@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Time{}, }, { Id: 1, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Time{}, }, }}, - prompt: []input.Input{{Token: 1}}, + prompt: []*input.Input{{Token: 1}}, longest: expected{result: 0, len: 0}, best: expected{result: 0, len: 0}, }, @@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}}, + Inputs: []*input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []input.Input{{Token: 1}, {Token: 2}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}}, longest: expected{result: 1, len: 2}, best: expected{result: 1, len: 2}, }, @@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Time{}, }, }}, - prompt: []input.Input{{Token: 2}}, + prompt: []*input.Input{{Token: 2}}, longest: expected{result: 0, len: 0}, best: expected{result: 1, len: 0}, }, @@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Time{}, }, }, }, - prompt: []input.Input{{Token: 1}}, + prompt: []*input.Input{{Token: 1}}, longest: expected{result: 0, len: 1}, best: expected{result: 1, len: 1}, }, @@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}}, + Inputs: []*input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []input.Input{{Token: 2}, {Token: 3}}, + prompt: []*input.Input{{Token: 2}, {Token: 3}}, longest: expected{result: 0, len: 0}, best: expected{result: 1, len: 0}, }, @@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: true, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{{Token: 1}}, + Inputs: []*input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []input.Input{{Token: 1}, {Token: 2}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}}, longest: expected{result: 1, len: 1}, best: expected{result: 1, len: 2}, }, @@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) { tests := []struct { name string cache InputCache - prompt []input.Input + prompt []*input.Input wantErr bool expectedSlotId int expectedPrompt int // expected length of remaining prompt @@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }, }, - prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, wantErr: false, expectedSlotId: 0, expectedPrompt: 1, // Only token 3 remains @@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }, }, - prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, wantErr: false, expectedSlotId: 0, expectedPrompt: 1, // Only token 3 remains @@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, }, }, - prompt: []input.Input{{Token: 1}, {Token: 2}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}}, wantErr: false, expectedSlotId: 0, expectedPrompt: 1, // Should leave 1 token for sampling @@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: true, lastUsed: time.Now().Add(-time.Second), }, }, }, - prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, wantErr: true, expectedSlotId: -1, expectedPrompt: -1, @@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt) + slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true) // Check error state if (err != nil) != tt.wantErr { @@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) { tests := []struct { name string numCtx int32 - inputs []input.Input + inputs []*input.Input numKeep int32 cacheErr bool wantErr any @@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) { { name: "Normal shift", numCtx: 10, - inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, + inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, numKeep: 2, cacheErr: false, // No error wantErr: nil, @@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) { { name: "Cache removal fails", numCtx: 10, - inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, + inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, numKeep: 2, cacheErr: true, wantErr: &ErrReprocessInputs{}, @@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) { } slot := &InputCacheSlot{ Id: 123, - Inputs: make([]input.Input, len(tt.inputs)), + Inputs: make([]*input.Input, len(tt.inputs)), } copy(slot.Inputs, tt.inputs) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 2f41f68f2..fafd850b3 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -28,9 +28,11 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" "github.com/ollama/ollama/runner/common" @@ -51,10 +53,10 @@ type Sequence struct { iBatch int // prompt inputs left to evaluate - inputs []input.Input + inputs []*input.Input // inputs that have been added to a batch but not yet submitted to Forward - pendingInputs []input.Input + pendingInputs []*input.Input // tokens that have been generated but not returned yet (e.g. for stop sequences) pendingResponses []string @@ -182,8 +184,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) { - var inputs []input.Input +func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) { + var inputs []*input.Input var ctxs []ml.Context var mmStore multimodalStore @@ -210,7 +212,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ } for _, t := range tokens { - inputs = append(inputs, input.Input{Token: t}) + inputs = append(inputs, &input.Input{Token: t}) } // image - decode and store @@ -243,7 +245,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ mmStore.addMultimodal(imageEmbeddings) - inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) + inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) postTokenize = true } } @@ -259,6 +261,37 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ return inputs, ctxs, mmStore, nil } +type batchState struct { + // id provides a counter for trace logging batches + id int + + // ctx holds the backend context used for this batch + ctx ml.Context + + // modelOutput holds the outputs from this batch + modelOutput ml.Tensor + + // batchInputs holds the input token pointers which may start as + // placeholders later filled in before calling ctx.Compute + batchInputs []*input.Input + + // batch contains the inputs for a model forward pass + batch input.Batch + + // full set of seqs at the time this batch was initiated + seqs []*Sequence + + // Signaled when this batches inputs are ready and compute can proceed + inputsReadyCh chan struct{} + + // Signaling when Compute is about to begin on this batch, and + // seqs have been updated to prepare for the next batch + computeStartedCh chan struct{} + + // Signaled when this batches outputs are complete and the next batch can proceed + outputsReadyCh chan struct{} +} + type Server struct { // modelPath is the location of the model to be loaded modelPath string @@ -290,6 +323,12 @@ type Server struct { // TODO (jmorganca): make this n_batch batchSize int + // Used to signal a hard failure during async processing which will panic the runner + hardErrCh chan error + + // Simple counter used only for trace logging batches + batchID int + // protects access to everything below this line // this is context state needed for decoding mu sync.Mutex @@ -362,33 +401,74 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) { s.seqsSem.Release(1) } +// track batch state between forwardBatch, computeBatch and predictForwardBatch + func (s *Server) run(ctx context.Context) { s.ready.Wait() + supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone + + var activeBatch batchState for { select { case <-ctx.Done(): return + case err := <-s.hardErrCh: + panic(err) default: - err := s.processBatch() + var err error + activeBatch, err = s.forwardBatch(activeBatch) if err != nil { panic(err) } + + if supportsAsync { + go s.computeBatch(activeBatch) + } else { + s.computeBatch(activeBatch) + } } } } -func (s *Server) processBatch() error { +// forwardBatch will calculate a batch. +func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) { + // If we have a pending batch still processing, wait until Compute has started + // before setting up the next batch so the seqs inputs are ready to receive their + // token values and we get the correct input pointers for the batchInputs + if pendingBatch.ctx != nil { + logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id) + <-pendingBatch.computeStartedCh + logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID) + nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch + } else { + logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID) + // No pendingBatch, so the inputs will be ready in the seqs immediately + nextBatch.inputsReadyCh = make(chan struct{}, 1) + nextBatch.inputsReadyCh <- struct{}{} + } + s.mu.Lock() for s.allNil() { s.cond.Wait() // Wait until an item is added } defer s.mu.Unlock() - ctx := s.model.Backend().NewContext() - defer ctx.Close() + nextBatch.ctx = s.model.Backend().NewContext() + defer func() { + if err != nil { + nextBatch.ctx.Close() + nextBatch.ctx = nil + } + }() + nextBatch.id = s.batchID + nextBatch.seqs = append([]*Sequence{}, s.seqs...) + nextBatch.computeStartedCh = make(chan struct{}, 1) + nextBatch.outputsReadyCh = make(chan struct{}, 1) - var batchInputs []int32 + // Prepare the seqs and batch, but defer the input token values as we may not be ready yet + var batchInputs []*input.Input + var batchOutputs []int32 var batch input.Batch resumeSeq := -1 @@ -396,7 +476,6 @@ func (s *Server) processBatch() error { for range s.seqs { seqIdx = (seqIdx + 1) % len(s.seqs) seq := s.seqs[seqIdx] - if seq == nil { continue } @@ -404,12 +483,13 @@ func (s *Server) processBatch() error { // if past the num predict limit if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { s.removeSequence(seqIdx, llm.DoneReasonLength) + nextBatch.seqs[seqIdx] = nil continue } if !s.cache.enabled { seq.inputs = append(seq.cache.Inputs, seq.inputs...) - seq.cache.Inputs = []input.Input{} + seq.cache.Inputs = []*input.Input{} } batchSize := s.batchSize @@ -442,25 +522,28 @@ func (s *Server) processBatch() error { break } - err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) + err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) if err != nil { var reprocess *ErrReprocessInputs if errors.As(err, &reprocess) { // Prepend these inputs to the sequence's inputs queue for reprocessing seq.inputs = append(reprocess.Inputs, seq.inputs...) // Skip this sequence but continue processing the rest + nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch + err = nil continue } else { - return err + return } } } - batchInputs = append(batchInputs, inp.Token) + batchInputs = append(batchInputs, seq.inputs[i]) if inp.Multimodal != nil { - mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false) + var mm []input.Multimodal + mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false) if err != nil { - return err + return } batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm}) } @@ -468,10 +551,11 @@ func (s *Server) processBatch() error { batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) batch.Sequences = append(batch.Sequences, seq.cache.Id) - seq.iBatch = len(batch.Outputs) - if i+1 == len(seq.inputs) { - batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) + seq.iBatch = len(batchOutputs) + if i+1 == len(seq.inputs) || seq.embeddingOnly { + batchOutputs = append(batchOutputs, int32(len(batchInputs)-1)) } + logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs)) seq.pendingInputs = append(seq.pendingInputs, inp) } @@ -485,73 +569,169 @@ func (s *Server) processBatch() error { } if len(batchInputs) == 0 { - return nil + logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID) + nextBatch.ctx.Close() + nextBatch.ctx = nil + return } + s.batchID++ - modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch) + // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute + batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs)) + batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs)) + nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch) if err != nil { - return fmt.Errorf("failed to decode batch: %w", err) + err = fmt.Errorf("failed to build graph: %w", err) + return + } + nextBatch.batchInputs = batchInputs + nextBatch.batch = batch + + return +} + +// Async processing of the next batch +func (s *Server) computeBatch(activeBatch batchState) { + if activeBatch.ctx == nil { + // Nothing to compute + return + } + defer activeBatch.ctx.Close() + + // Wait until inputs are ready + logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id) + <-activeBatch.inputsReadyCh + logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id) + + // Once we complete, signal the next batch of inputs are ready + // This will unblock the next computeBatch, or forwardBatch if new seqs come in + defer func() { + logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id) + activeBatch.outputsReadyCh <- struct{}{} + }() + + s.mu.Lock() + + // Gather the actual input token values now that they're ready + batchInputs := make([]int32, len(activeBatch.batchInputs)) + for i := range batchInputs { + batchInputs[i] = activeBatch.batchInputs[i].Token } - logits := modelOutput.Floats() - + // Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens + // so that forwardBatch can build a batchInputs set which will eventually contain the actual + // decoded tokens. + nextBatchTokens := make([]*input.Input, len(s.seqs)) + iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock for i, seq := range s.seqs { + iBatches[i] = -1 if seq == nil { continue } + // Skip over any newly added or skipped sequences + if activeBatch.seqs[i] == nil { + continue + } - // After calling Forward, pending inputs are now in the cache + // Detect if the sequence we're processing has already been completed and replaced + // with a new sequence + if seq != activeBatch.seqs[i] { + logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i) + continue + } + + // Pending inputs will actually be in the cache after we call Compute. + // However, we have already resolved any placeholder tokens. + // + // It's possible for incoming sequences to look at the values that we've + // added to the cache here and start relying on them before we've done + // the computation. This is OK as long as we ensure that this batch's + // computation happens before any future batch's and we never fail + // (unless we take down the whole runner). if len(seq.pendingInputs) > 0 { seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) - seq.pendingInputs = []input.Input{} + seq.pendingInputs = []*input.Input{} } // don't sample prompt processing if len(seq.inputs) != 0 { if !s.cache.enabled { - return errors.New("caching disabled but unable to fit entire input in a batch") + s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch") + s.mu.Unlock() + return } continue } seq.numPredicted++ + nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats + seq.inputs = []*input.Input{nextToken} + nextBatchTokens[i] = nextToken + iBatches[i] = seq.iBatch + } + + // At this point the seqs are ready for forwardBatch to move forward so unblock + s.mu.Unlock() + + activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs) + activeBatch.ctx.ComputeWithNotify( + func() { + logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id) + activeBatch.computeStartedCh <- struct{}{} + }, + activeBatch.modelOutput) + + outputs := activeBatch.modelOutput.Floats() + + logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id) + + s.mu.Lock() + defer s.mu.Unlock() + + logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id) + for i, seq := range s.seqs { + if seq == nil || nextBatchTokens[i] == nil { + continue + } + if seq.numPredicted == 1 { seq.startGenerationTime = time.Now() } // if done processing the prompt, generate an embedding and return if seq.embeddingOnly { - // TODO(jessegross): Embedding support - slog.Warn("generation of embedding outputs not yet supported") + seq.embedding <- outputs s.removeSequence(i, llm.DoneReasonStop) continue } // sample a token - vocabSize := len(logits) / len(batch.Outputs) - - token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]) + vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0) + logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches) + token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) if err != nil { - return fmt.Errorf("failed to sample token: %w", err) + s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err) + return } + nextBatchTokens[i].Token = token + // if it's an end of sequence token, break if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) { // TODO (jmorganca): we should send this back // as it's important for the /api/generate context // seq.responses <- piece - + logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i) s.removeSequence(i, llm.DoneReasonStop) continue } piece, err := s.model.(model.TextProcessor).Decode([]int32{token}) if err != nil { - return err + s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err) + return } - seq.inputs = []input.Input{{Token: token}} - seq.pendingResponses = append(seq.pendingResponses, piece) sequence := strings.Join(seq.pendingResponses, "") @@ -575,6 +755,7 @@ func (s *Server) processBatch() error { if tokenTruncated || origLen == newLen { tokenLen-- } + seq.cache.Inputs = seq.cache.Inputs[:tokenLen] s.removeSequence(i, llm.DoneReasonStop) @@ -593,8 +774,6 @@ func (s *Server) processBatch() error { s.removeSequence(i, llm.DoneReasonConnectionClosed) } } - - return nil } func (s *Server) completion(w http.ResponseWriter, r *http.Request) { @@ -665,7 +844,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() s.seqsSem.Release(1) @@ -721,6 +900,67 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } +func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { + if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone { + http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) + return + } + + var req llm.EmbeddingRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true}) + if err != nil { + http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError) + return + } + + if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { + if errors.Is(err, context.Canceled) { + slog.Info("aborting embedding request due to client closing the connection") + } else { + http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError) + } + return + } + + s.mu.Lock() + found := false + for i, sq := range s.seqs { + if sq == nil { + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false) + if err != nil { + s.mu.Unlock() + s.seqsSem.Release(1) + http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError) + return + } + + s.seqs[i] = seq + s.cond.Signal() + found = true + break + } + } + s.mu.Unlock() + + if !found { + s.seqsSem.Release(1) + http.Error(w, "could not find an available sequence", http.StatusInternalServerError) + return + } + + if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ + Embedding: <-seq.embedding, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + func (s *Server) health(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ @@ -736,7 +976,10 @@ func (s *Server) reserveWorstCaseGraph() error { defer ctx.Close() var err error - inputs := make([]input.Input, s.batchSize) + inputs := make([]*input.Input, s.batchSize) + for i := range inputs { + inputs[i] = &input.Input{} + } mmStore := newMultimodalStore() // Multimodal strategy: @@ -778,8 +1021,11 @@ func (s *Server) reserveWorstCaseGraph() error { } if len(inputs) < s.batchSize { - newInputs := make([]input.Input, s.batchSize) + newInputs := make([]*input.Input, s.batchSize) copy(newInputs, inputs) + for i := len(inputs); i < s.batchSize; i++ { + newInputs[i] = &input.Input{} + } inputs = newInputs } } @@ -803,12 +1049,8 @@ func (s *Server) reserveWorstCaseGraph() error { batch.Positions[i] = int32(i) } - batch.Outputs = make([]int32, s.parallel) - for i := range batch.Outputs { - batch.Outputs[i] = int32(i) - } - batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) + batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel) cache := s.model.Config().Cache if cache != nil { @@ -843,7 +1085,12 @@ func (s *Server) allocModel( defer func() { if r := recover(); r != nil { if err, ok := r.(error); ok { - panicErr = err + var noMem ml.ErrNoMem + if errors.As(err, &noMem) { + panicErr = noMem + } else { + panic(r) + } } else { panic(r) } @@ -989,6 +1236,52 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) { } } +// info is the handler called by the Ollama server to report information +// about the GPU devices in use by this runner +func (s *Server) info(w http.ResponseWriter, r *http.Request) { + s.loadMu.Lock() + defer s.loadMu.Unlock() + + w.Header().Set("Content-Type", "application/json") + + m := s.model + + if m == nil { + startLoad := time.Now() + + // Dummy load to get the backend wired up + f, err := os.CreateTemp("", "*.bin") + if err != nil { + http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError) + return + } + defer f.Close() + defer os.Remove(f.Name()) + + if err := ggml.WriteGGUF(f, ggml.KV{ + "general.architecture": "llama", + "tokenizer.ggml.model": "gpt2", + }, nil); err != nil { + http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError) + return + } + + m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}}) + if err != nil { + http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError) + return + } + slog.Debug("dummy model load took", "duration", time.Since(startLoad)) + } + + startDevices := time.Now() + infos := m.Backend().BackendDevices() + slog.Debug("gathering device infos took", "duration", time.Since(startDevices)) + if err := json.NewEncoder(w).Encode(&infos); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + func Execute(args []string) error { fs := flag.NewFlagSet("runner", flag.ExitOnError) mpath := fs.String("model", "", "Path to model binary file") @@ -1011,6 +1304,7 @@ func Execute(args []string) error { server := &Server{ modelPath: *mpath, status: llm.ServerStatusLaunched, + hardErrCh: make(chan error, 1), } server.cond = sync.NewCond(&server.mu) @@ -1028,11 +1322,9 @@ func Execute(args []string) error { mux := http.NewServeMux() // TODO: support embeddings + mux.HandleFunc("GET /info", server.info) mux.HandleFunc("POST /load", server.load) - mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) - }) - + mux.HandleFunc("POST /embedding", server.embeddings) mux.HandleFunc("POST /completion", server.completion) mux.HandleFunc("GET /health", server.health) diff --git a/sample/samplers_test.go b/sample/samplers_test.go index b720f027c..eb10295d4 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -82,7 +82,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding { merges := make([]string, 0, 1) // Only need vocab for Grammar Test return model.NewBytePairEncoding( - ``, &model.Vocabulary{ Values: tokens, Types: make([]int32, len(vocab)), diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 27f3eb9d4..4c9d31193 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -78,7 +78,7 @@ function checkEnv() { } -function buildOllama() { +function buildCPU() { mkdir -Force -path "${script:DIST_DIR}\" if ($script:ARCH -ne "arm64") { Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}" @@ -90,20 +90,72 @@ function buildOllama() { if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build --component CPU --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } +} +function buildCUDA11() { + # CUDA v11 claims to be compatible with MSVC 2022, but the latest updates are no longer compatible + # 19.40 is the last compiler version that works, but recent udpates are 19.43 + # So this pins to MSVC 2019 for best compatibility + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { $hashEnv = @{} Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } - if ("$script:CUDA_DIRS".Contains("v12")) { - $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }} - $env:CUDAToolkit_ROOT=$hashEnv[$v12] - write-host "Building CUDA v12 backend libraries" - & cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR + if ("$script:CUDA_DIRS".Contains("v11")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }} + write-host "Building CUDA v11 backend libraries $cuda" + $env:CUDAToolkit_ROOT=$cuda + & cmake --fresh --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component "CUDA" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + } +} + +function buildCUDA12() { + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { + $hashEnv = @{} + Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } + if ("$script:CUDA_DIRS".Contains("v12.8")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12_8")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }} + write-host "Building CUDA v12 backend libraries $cuda" + $env:CUDAToolkit_ROOT=$cuda + & cmake --fresh --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build --component "CUDA" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } + } +} + +function buildCUDA13() { + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { + $hashEnv = @{} + Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } + if ("$script:CUDA_DIRS".Contains("v13")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V13")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }} + $env:CUDAToolkit_ROOT=$cuda + write-host "Building CUDA v13 backend libraries $cuda" + & cmake --fresh --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset "CUDA 13" --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component "CUDA" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + } +} + +function buildROCm() { + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { if ($env:HIP_PATH) { write-host "Building ROCm backend libraries" if (-Not (get-command -ErrorAction silent ninja)) { @@ -113,7 +165,7 @@ function buildOllama() { $env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe" $env:HIP_PLATFORM="amd" $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" - & cmake --fresh --preset "ROCm 6" -G Ninja ` + & cmake --fresh --preset "ROCm 6" -G Ninja -DOLLAMA_RUNNER_DIR="rocm" ` -DCMAKE_C_COMPILER=clang ` -DCMAKE_CXX_COMPILER=clang++ ` -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` @@ -129,6 +181,10 @@ function buildOllama() { if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } } +} + +function buildOllama() { + mkdir -Force -path "${script:DIST_DIR}\" write-host "Building ollama CLI" & go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" . if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} @@ -223,7 +279,7 @@ function distZip() { write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip" Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") { - Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" + Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" } } @@ -236,6 +292,10 @@ function distZip() { checkEnv try { if ($($args.count) -eq 0) { + buildCPU + buildCUDA12 + buildCUDA13 + buildROCm buildOllama buildApp gatherDependencies diff --git a/scripts/env.sh b/scripts/env.sh index 65a970bdc..4f5641fd3 100644 --- a/scripts/env.sh +++ b/scripts/env.sh @@ -16,6 +16,7 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \ --build-arg=OLLAMA_FAST_BUILD \ --build-arg=CUSTOM_CPU_FLAGS \ --build-arg=GPU_RUNNER_CPU_FLAGS \ + --build-arg=PARALLEL \ --build-arg=AMDGPU_TARGETS" echo "Building Ollama" diff --git a/server/create.go b/server/create.go index bd970876f..19f24ec80 100644 --- a/server/create.go +++ b/server/create.go @@ -10,8 +10,11 @@ import ( "io" "io/fs" "log/slog" + "net" "net/http" + "net/url" "os" + "path" "path/filepath" "slices" "strings" @@ -39,6 +42,14 @@ var ( ) func (s *Server) CreateHandler(c *gin.Context) { + config := &ConfigV2{ + OS: "linux", + Architecture: "amd64", + RootFS: RootFS{ + Type: "layers", + }, + } + var r api.CreateRequest if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) @@ -48,6 +59,9 @@ func (s *Server) CreateHandler(c *gin.Context) { return } + config.Renderer = r.Renderer + config.Parser = r.Parser + for v := range r.Files { if !fs.ValidPath(v) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()}) @@ -77,20 +91,34 @@ func (s *Server) CreateHandler(c *gin.Context) { oldManifest, _ := ParseNamedManifest(name) var baseLayers []*layerGGML + var err error + var remote bool + if r.From != "" { - slog.Debug("create model from model name") + slog.Debug("create model from model name", "from", r.From) fromName := model.ParseName(r.From) if !fromName.IsValid() { ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest} return } + if r.RemoteHost != "" { + ru, err := remoteURL(r.RemoteHost) + if err != nil { + ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest} + return + } - ctx, cancel := context.WithCancel(c.Request.Context()) - defer cancel() + config.RemoteModel = r.From + config.RemoteHost = ru + remote = true + } else { + ctx, cancel := context.WithCancel(c.Request.Context()) + defer cancel() - baseLayers, err = parseFromModel(ctx, fromName, fn) - if err != nil { - ch <- gin.H{"error": err.Error()} + baseLayers, err = parseFromModel(ctx, fromName, fn) + if err != nil { + ch <- gin.H{"error": err.Error()} + } } } else if r.Files != nil { baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn) @@ -110,7 +138,7 @@ func (s *Server) CreateHandler(c *gin.Context) { } var adapterLayers []*layerGGML - if r.Adapters != nil { + if !remote && r.Adapters != nil { adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn) if err != nil { for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} { @@ -128,7 +156,56 @@ func (s *Server) CreateHandler(c *gin.Context) { baseLayers = append(baseLayers, adapterLayers...) } - if err := createModel(r, name, baseLayers, fn); err != nil { + // Info is not currently exposed by Modelfiles, but allows overriding various + // config values + if r.Info != nil { + caps, ok := r.Info["capabilities"] + if ok { + switch tcaps := caps.(type) { + case []any: + caps := make([]string, len(tcaps)) + for i, c := range tcaps { + str, ok := c.(string) + if !ok { + continue + } + caps[i] = str + } + config.Capabilities = append(config.Capabilities, caps...) + } + } + + strFromInfo := func(k string) string { + v, ok := r.Info[k] + if ok { + val := v.(string) + return val + } + return "" + } + + vFromInfo := func(k string) float64 { + v, ok := r.Info[k] + if ok { + val := v.(float64) + return val + } + return 0 + } + + config.ModelFamily = strFromInfo("model_family") + if config.ModelFamily != "" { + config.ModelFamilies = []string{config.ModelFamily} + } + + config.BaseName = strFromInfo("base_name") + config.FileType = strFromInfo("quantization_level") + config.ModelType = strFromInfo("parameter_size") + config.ContextLen = int(vFromInfo("context_length")) + config.EmbedLen = int(vFromInfo("embedding_length")) + } + + if err := createModel(r, name, baseLayers, config, fn); err != nil { if errors.Is(err, errBadTemplate) { ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest} return @@ -154,6 +231,51 @@ func (s *Server) CreateHandler(c *gin.Context) { streamResponse(c, ch) } +func remoteURL(raw string) (string, error) { + // Special‑case: user supplied only a path ("/foo/bar"). + if strings.HasPrefix(raw, "/") { + return (&url.URL{ + Scheme: "http", + Host: net.JoinHostPort("localhost", "11434"), + Path: path.Clean(raw), + }).String(), nil + } + + if !strings.Contains(raw, "://") { + raw = "http://" + raw + } + + if raw == "ollama.com" || raw == "http://ollama.com" { + raw = "https://ollama.com:443" + } + + u, err := url.Parse(raw) + if err != nil { + return "", fmt.Errorf("parse error: %w", err) + } + + if u.Host == "" { + u.Host = "localhost" + } + + hostPart, portPart, err := net.SplitHostPort(u.Host) + if err == nil { + u.Host = net.JoinHostPort(hostPart, portPart) + } else { + u.Host = net.JoinHostPort(u.Host, "11434") + } + + if u.Path != "" { + u.Path = path.Clean(u.Path) + } + + if u.Path == "/" { + u.Path = "" + } + + return u.String(), nil +} + func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) { switch detectModelTypeFromFiles(files) { case "safetensors": @@ -316,15 +438,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) { return ggml.KV{}, fmt.Errorf("no base model was found") } -func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) { - config := ConfigV2{ - OS: "linux", - Architecture: "amd64", - RootFS: RootFS{ - Type: "layers", - }, - } - +func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) { var layers []Layer for _, layer := range baseLayers { if layer.GGML != nil { @@ -404,7 +518,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, return err } - configLayer, err := createConfigLayer(layers, config) + configLayer, err := createConfigLayer(layers, *config) if err != nil { return err } diff --git a/server/create_test.go b/server/create_test.go index 59a07ff14..061efb81a 100644 --- a/server/create_test.go +++ b/server/create_test.go @@ -104,3 +104,154 @@ func TestConvertFromSafetensors(t *testing.T) { }) } } + +func TestRemoteURL(t *testing.T) { + tests := []struct { + name string + input string + expected string + hasError bool + }{ + { + name: "absolute path", + input: "/foo/bar", + expected: "http://localhost:11434/foo/bar", + hasError: false, + }, + { + name: "absolute path with cleanup", + input: "/foo/../bar", + expected: "http://localhost:11434/bar", + hasError: false, + }, + { + name: "root path", + input: "/", + expected: "http://localhost:11434/", + hasError: false, + }, + { + name: "host without scheme", + input: "example.com", + expected: "http://example.com:11434", + hasError: false, + }, + { + name: "host with port", + input: "example.com:8080", + expected: "http://example.com:8080", + hasError: false, + }, + { + name: "full URL", + input: "https://example.com:8080/path", + expected: "https://example.com:8080/path", + hasError: false, + }, + { + name: "full URL with path cleanup", + input: "https://example.com:8080/path/../other", + expected: "https://example.com:8080/other", + hasError: false, + }, + { + name: "ollama.com special case", + input: "ollama.com", + expected: "https://ollama.com:443", + hasError: false, + }, + { + name: "http ollama.com special case", + input: "http://ollama.com", + expected: "https://ollama.com:443", + hasError: false, + }, + { + name: "URL with only host", + input: "http://example.com", + expected: "http://example.com:11434", + hasError: false, + }, + { + name: "URL with root path cleaned", + input: "http://example.com/", + expected: "http://example.com:11434", + hasError: false, + }, + { + name: "invalid URL", + input: "http://[::1]:namedport", // invalid port + expected: "", + hasError: true, + }, + { + name: "empty string", + input: "", + expected: "http://localhost:11434", + hasError: false, + }, + { + name: "host with scheme but no port", + input: "http://localhost", + expected: "http://localhost:11434", + hasError: false, + }, + { + name: "complex path cleanup", + input: "/a/b/../../c/./d", + expected: "http://localhost:11434/c/d", + hasError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := remoteURL(tt.input) + + if tt.hasError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestRemoteURL_Idempotent(t *testing.T) { + // Test that applying remoteURL twice gives the same result as applying it once + testInputs := []string{ + "/foo/bar", + "example.com", + "https://example.com:8080/path", + "ollama.com", + "http://localhost:11434", + } + + for _, input := range testInputs { + t.Run(input, func(t *testing.T) { + firstResult, err := remoteURL(input) + if err != nil { + t.Fatalf("first call failed: %v", err) + } + + secondResult, err := remoteURL(firstResult) + if err != nil { + t.Fatalf("second call failed: %v", err) + } + + if firstResult != secondResult { + t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult) + } + }) + } +} diff --git a/server/images.go b/server/images.go index 504eb95cf..9466b7fb4 100644 --- a/server/images.go +++ b/server/images.go @@ -24,6 +24,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/fs/gguf" + "github.com/ollama/ollama/model/parsers" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/template" "github.com/ollama/ollama/thinking" @@ -73,29 +74,38 @@ func (m *Model) Capabilities() []model.Capability { capabilities := []model.Capability{} // Check for completion capability - f, err := gguf.Open(m.ModelPath) - if err == nil { - defer f.Close() + if m.ModelPath != "" { + f, err := gguf.Open(m.ModelPath) + if err == nil { + defer f.Close() - if f.KeyValue("pooling_type").Valid() { - capabilities = append(capabilities, model.CapabilityEmbedding) + if f.KeyValue("pooling_type").Valid() { + capabilities = append(capabilities, model.CapabilityEmbedding) + } else { + // If no embedding is specified, we assume the model supports completion + capabilities = append(capabilities, model.CapabilityCompletion) + } + if f.KeyValue("vision.block_count").Valid() { + capabilities = append(capabilities, model.CapabilityVision) + } } else { - // If no embedding is specified, we assume the model supports completion - capabilities = append(capabilities, model.CapabilityCompletion) + slog.Error("couldn't open model file", "error", err) } - if f.KeyValue("vision.block_count").Valid() { - capabilities = append(capabilities, model.CapabilityVision) + } else if len(m.Config.Capabilities) > 0 { + for _, c := range m.Config.Capabilities { + capabilities = append(capabilities, model.Capability(c)) } } else { - slog.Error("couldn't open model file", "error", err) + slog.Warn("unknown capabilities for model", "model", m.Name) } if m.Template == nil { return capabilities } + builtinParser := parsers.ParserForName(m.Config.Parser) // Check for tools capability - if slices.Contains(m.Template.Vars(), "tools") { + if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) { capabilities = append(capabilities, model.CapabilityTools) } @@ -109,10 +119,16 @@ func (m *Model) Capabilities() []model.Capability { capabilities = append(capabilities, model.CapabilityVision) } + // Skip the thinking check if it's already set + if slices.Contains(capabilities, "thinking") { + return capabilities + } + // Check for thinking capability openingTag, closingTag := thinking.InferTags(m.Template.Template) hasTags := openingTag != "" && closingTag != "" - if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) { + isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) + if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) { capabilities = append(capabilities, model.CapabilityThinking) } @@ -198,6 +214,20 @@ func (m *Model) String() string { }) } + if m.Config.Renderer != "" { + modelfile.Commands = append(modelfile.Commands, parser.Command{ + Name: "renderer", + Args: m.Config.Renderer, + }) + } + + if m.Config.Parser != "" { + modelfile.Commands = append(modelfile.Commands, parser.Command{ + Name: "parser", + Args: m.Config.Parser, + }) + } + for k, v := range m.Options { switch v := v.(type) { case []any: @@ -236,8 +266,19 @@ type ConfigV2 struct { ModelFormat string `json:"model_format"` ModelFamily string `json:"model_family"` ModelFamilies []string `json:"model_families"` - ModelType string `json:"model_type"` - FileType string `json:"file_type"` + ModelType string `json:"model_type"` // shown as Parameter Size + FileType string `json:"file_type"` // shown as Quantization Level + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` + + RemoteHost string `json:"remote_host,omitempty"` + RemoteModel string `json:"remote_model,omitempty"` + + // used for remotes + Capabilities []string `json:"capabilities,omitempty"` + ContextLen int `json:"context_length,omitempty"` + EmbedLen int `json:"embedding_length,omitempty"` + BaseName string `json:"base_name,omitempty"` // required by spec Architecture string `json:"architecture"` diff --git a/server/internal/internal/backoff/backoff.go b/server/internal/internal/backoff/backoff.go index 1f0634f7c..08b4ed7f9 100644 --- a/server/internal/internal/backoff/backoff.go +++ b/server/internal/internal/backoff/backoff.go @@ -25,10 +25,7 @@ func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] { // n^2 backoff timer is a little smoother than the // common choice of 2^n. - d := time.Duration(n*n) * 10 * time.Millisecond - if d > maxBackoff { - d = maxBackoff - } + d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff) // Randomize the delay between 0.5-1.5 x msec, in order // to prevent accidental "thundering herd" problems. d = time.Duration(float64(d) * (rand.Float64() + 0.5)) diff --git a/server/prompt.go b/server/prompt.go index f1d8020ea..56bc63030 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -11,6 +11,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/model/renderers" "github.com/ollama/ollama/template" ) @@ -41,18 +42,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. } } - thinkVal := false - thinkLevel := "" - if think != nil { - thinkVal = think.Bool() - thinkLevel = think.String() - } - var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { + p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think) + if err != nil { return "", nil, err } - s, err := tokenize(ctx, b.String()) + s, err := tokenize(ctx, p) if err != nil { return "", nil, err } @@ -101,6 +96,23 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. } // truncate any messages that do not fit into the context window + p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think) + if err != nil { + return "", nil, err + } + + return p, images, nil +} + +func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { + if m.Config.Renderer != "" { + rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think) + if err != nil { + return "", err + } + return rendered, nil + } + var b bytes.Buffer thinkVal := false thinkLevel := "" @@ -108,9 +120,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. thinkVal = think.Bool() thinkLevel = think.String() } - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { - return "", nil, err + if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { + return "", err } - - return b.String(), images, nil + return b.String(), nil } diff --git a/server/routes.go b/server/routes.go index 60b7e3e84..7e0ba1c60 100644 --- a/server/routes.go +++ b/server/routes.go @@ -4,6 +4,7 @@ import ( "bytes" "cmp" "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -15,6 +16,7 @@ import ( "net" "net/http" "net/netip" + "net/url" "os" "os/signal" "slices" @@ -28,12 +30,14 @@ import ( "golang.org/x/sync/errgroup" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/discover" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/model/parsers" "github.com/ollama/ollama/openai" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" @@ -45,6 +49,20 @@ import ( "github.com/ollama/ollama/version" ) +const signinURLStr = "https://ollama.com/connect?name=%s&key=%s" + +func shouldUseHarmony(model *Model) bool { + if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { + // heuristic to check whether the template expects to be parsed via harmony: + // search for harmony tags that are nearly always used + if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") { + return true + } + } + + return false +} + func experimentEnabled(name string) bool { return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) } @@ -135,6 +153,17 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C return runner.llama, model, &opts, nil } +func signinURL() (string, error) { + pubKey, err := auth.GetPublicKey() + if err != nil { + return "", err + } + + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + h, _ := os.Hostname() + return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil +} + func (s *Server) GenerateHandler(c *gin.Context) { checkpointStart := time.Now() var req api.GenerateRequest @@ -175,8 +204,92 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + origModel := req.Model + + remoteURL, err := url.Parse(m.Config.RemoteHost) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) { + slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname()) + c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"}) + return + } + + req.Model = m.Config.RemoteModel + + if req.Template == "" && m.Template.String() != "" { + req.Template = m.Template.String() + } + + if req.Options == nil { + req.Options = map[string]any{} + } + + for k, v := range m.Options { + if _, ok := req.Options[k]; !ok { + req.Options[k] = v + } + } + + // update the system prompt from the model if one isn't already specified + if req.System == "" && m.System != "" { + req.System = m.System + } + + if len(m.Messages) > 0 { + slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead") + } + + fn := func(resp api.GenerateResponse) error { + resp.Model = origModel + resp.RemoteModel = m.Config.RemoteModel + resp.RemoteHost = m.Config.RemoteHost + + data, err := json.Marshal(resp) + if err != nil { + return err + } + + if _, err = c.Writer.Write(append(data, '\n')); err != nil { + return err + } + c.Writer.Flush() + return nil + } + + client := api.NewClient(remoteURL, http.DefaultClient) + err = client.Generate(c, &req, fn) + if err != nil { + var authError api.AuthorizationError + if errors.As(err, &authError) { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) + return + } + + c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + var apiError api.StatusError + if errors.As(err, &apiError) { + c.JSON(apiError.StatusCode, apiError) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + return + } + // expire the runner - if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { + if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { s.sched.expireRunner(m) c.JSON(http.StatusOK, api.GenerateResponse{ @@ -194,17 +307,21 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - useHarmony := shouldUseHarmony(*m) && !req.Raw - var harmonyMessageHandler *HarmonyMessageHandler - var harmonyToolParser *HarmonyToolCallAccumulator - if useHarmony { - harmonyMessageHandler = NewHarmonyMessageHandler() - harmonyMessageHandler.harmonyParser.AddImplicitStart() - harmonyToolParser = harmonyMessageHandler.CreateToolParser() + var builtinParser parsers.Parser + if shouldUseHarmony(m) && m.Config.Parser == "" { + m.Config.Parser = "harmony" } - // Validate Think value: string values currently only allowed for gptoss models - if req.Think != nil && req.Think.IsString() && !useHarmony { + if !req.Raw && m.Config.Parser != "" { + builtinParser = parsers.ParserForName(m.Config.Parser) + if builtinParser != nil { + // no tools or last message for generate endpoint + builtinParser.Init(nil, nil) + } + } + + // Validate Think value: string values currently only allowed for harmony/gptoss models + if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) return } @@ -316,10 +433,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { // If debug mode is enabled, return the rendered template instead of calling the model if req.DebugRenderOnly { - c.JSON(http.StatusOK, api.DebugTemplateResponse{ + c.JSON(http.StatusOK, api.GenerateResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), - DebugInfo: api.DebugInfo{ + DebugInfo: &api.DebugInfo{ RenderedTemplate: prompt, ImageCount: len(images), }, @@ -328,13 +445,16 @@ func (s *Server) GenerateHandler(c *gin.Context) { } var thinkingState *thinking.Parser - if !useHarmony { + if builtinParser == nil { openingTag, closingTag := thinking.InferTags(m.Template.Template) if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" { thinkingState = &thinking.Parser{ OpeningTag: openingTag, ClosingTag: closingTag, } + if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) { + thinkingState.AddContent(openingTag) + } } } @@ -362,11 +482,17 @@ func (s *Server) GenerateHandler(c *gin.Context) { }, } - if useHarmony { - content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser) + if builtinParser != nil { + content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done) + if err != nil { + ch <- gin.H{"error": err.Error()} + return + } res.Response = content res.Thinking = thinking - harmonyToolParser.Add(toolContent) + if cr.Done && len(toolCalls) > 0 { + res.ToolCalls = toolCalls + } } else if thinkingState != nil { thinking, content := thinkingState.AddContent(cr.Content) res.Thinking = thinking @@ -378,26 +504,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if cr.Done { - if useHarmony { - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) - ch <- gin.H{"error": errStr} - return - } - - res.ToolCalls = append(res.ToolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: *toolName, - Arguments: args, - }, - }) - } - } - res.DoneReason = cr.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) @@ -412,7 +518,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { } } - if useHarmony { + if builtinParser != nil { // only send messages with meaningful content (empty messages confuse clients) if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 { ch <- res @@ -475,7 +581,6 @@ func (s *Server) EmbedHandler(c *gin.Context) { } truncate := true - if req.Truncate != nil && !*req.Truncate { truncate = false } @@ -538,11 +643,27 @@ func (s *Server) EmbedHandler(c *gin.Context) { ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) if len(tokens) > ctxLen { if !truncate { - c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"}) + c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"}) + return + } + + if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) { + ctxLen-- + } + + if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) { + ctxLen-- + } + + slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens)) + if ctxLen <= 0 { + // return error if the truncated input would be empty or just special tokens + c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"}) return } tokens = tokens[:ctxLen] + s, err = r.Detokenize(c.Request.Context(), tokens) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -563,7 +684,12 @@ func (s *Server) EmbedHandler(c *gin.Context) { if err != nil { return err } - embeddings[i] = normalize(embedding) + // TODO: this first normalization should be done by the model + embedding = normalize(embedding) + if req.Dimensions > 0 && req.Dimensions < len(embedding) { + embedding = normalize(embedding[:req.Dimensions]) + } + embeddings[i] = embedding return nil }) } @@ -589,11 +715,7 @@ func normalize(vec []float32) []float32 { sum += v * v } - norm := float32(0.0) - if sum > 0 { - norm = float32(1.0 / math.Sqrt(float64(sum))) - } - + norm := float32(1.0 / max(math.Sqrt(float64(sum)), 1e-12)) for i := range vec { vec[i] *= norm } @@ -908,6 +1030,28 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { ModifiedAt: manifest.fi.ModTime(), } + if m.Config.RemoteHost != "" { + resp.RemoteHost = m.Config.RemoteHost + resp.RemoteModel = m.Config.RemoteModel + + if m.Config.ModelFamily != "" { + resp.ModelInfo = make(map[string]any) + resp.ModelInfo["general.architecture"] = m.Config.ModelFamily + + if m.Config.BaseName != "" { + resp.ModelInfo["general.basename"] = m.Config.BaseName + } + + if m.Config.ContextLen > 0 { + resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen + } + + if m.Config.EmbedLen > 0 { + resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen + } + } + } + var params []string cs := 30 for k, v := range m.Options { @@ -938,6 +1082,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { fmt.Fprint(&sb, m.String()) resp.Modelfile = sb.String() + // skip loading tensor information if this is a remote model + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + return resp, nil + } + kvData, tensors, err := getModelData(m.ModelPath, req.Verbose) if err != nil { return nil, err @@ -1014,11 +1163,13 @@ func (s *Server) ListHandler(c *gin.Context) { // tag should never be masked models = append(models, api.ListModelResponse{ - Model: n.DisplayShortest(), - Name: n.DisplayShortest(), - Size: m.Size(), - Digest: m.digest, - ModifiedAt: m.fi.ModTime(), + Model: n.DisplayShortest(), + Name: n.DisplayShortest(), + RemoteModel: cf.RemoteModel, + RemoteHost: cf.RemoteHost, + Size: m.Size(), + Digest: m.digest, + ModifiedAt: m.fi.ModTime(), Details: api.ModelDetails{ Format: cf.ModelFormat, Family: cf.ModelFamily, @@ -1278,6 +1429,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/show", s.ShowHandler) r.DELETE("/api/delete", s.DeleteHandler) + r.POST("/api/me", s.WhoamiHandler) + + r.POST("/api/signout", s.SignoutHandler) + // deprecated + r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler) + // Create r.POST("/api/create", s.CreateHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler) @@ -1400,8 +1557,8 @@ func Serve(ln net.Listener) error { // At startup we retrieve GPU information so we can get log messages before loading a model // This will log warnings to the log in case we have problems with detected GPUs - gpus := discover.GetGPUInfo() - gpus.LogDetails() + gpus := discover.GPUDevices(ctx, nil) + discover.LogDetails(gpus) var totalVRAM uint64 for _, gpu := range gpus { @@ -1474,6 +1631,70 @@ func streamResponse(c *gin.Context, ch chan any) { }) } +func (s *Server) WhoamiHandler(c *gin.Context) { + // todo allow other hosts + u, err := url.Parse("https://ollama.com") + if err != nil { + slog.Error(err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"}) + return + } + + client := api.NewClient(u, http.DefaultClient) + user, err := client.Whoami(c) + if err != nil { + slog.Error(err.Error()) + } + + // user isn't signed in + if user != nil && user.Name == "" { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) + return + } + + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + + c.JSON(http.StatusOK, user) +} + +func (s *Server) SignoutHandler(c *gin.Context) { + pubKey, err := auth.GetPublicKey() + if err != nil { + slog.Error("couldn't get public key", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"}) + return + } + + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + + // todo allow other hosts + u, err := url.Parse("https://ollama.com") + if err != nil { + slog.Error(err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"}) + return + } + + client := api.NewClient(u, http.DefaultClient) + err = client.Disconnect(c, encKey) + if err != nil { + var authError api.AuthorizationError + if errors.As(err, &authError) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"}) + return + } + + c.JSON(http.StatusOK, nil) +} + func (s *Server) PsHandler(c *gin.Context) { models := []api.ProcessModelResponse{} @@ -1530,21 +1751,34 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - // expire the runner - if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { - model, err := GetModel(req.Model) - if err != nil { - switch { - case os.IsNotExist(err): - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) - case err.Error() == errtypes.InvalidModelNameErrMsg: - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - default: - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - return + name := model.ParseName(req.Model) + if !name.IsValid() { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + } + + name, err := getExistingName(name) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + } + + m, err := GetModel(req.Model) + if err != nil { + switch { + case os.IsNotExist(err): + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + case err.Error() == errtypes.InvalidModelNameErrMsg: + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } - s.sched.expireRunner(model) + return + } + + // expire the runner + if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { + s.sched.expireRunner(m) c.JSON(http.StatusOK, api.ChatResponse{ Model: req.Model, @@ -1556,6 +1790,83 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + origModel := req.Model + + remoteURL, err := url.Parse(m.Config.RemoteHost) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) { + slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname()) + c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"}) + return + } + + req.Model = m.Config.RemoteModel + if req.Options == nil { + req.Options = map[string]any{} + } + + msgs := append(m.Messages, req.Messages...) + if req.Messages[0].Role != "system" && m.System != "" { + msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...) + } + msgs = filterThinkTags(msgs, m) + req.Messages = msgs + + for k, v := range m.Options { + if _, ok := req.Options[k]; !ok { + req.Options[k] = v + } + } + + fn := func(resp api.ChatResponse) error { + resp.Model = origModel + resp.RemoteModel = m.Config.RemoteModel + resp.RemoteHost = m.Config.RemoteHost + + data, err := json.Marshal(resp) + if err != nil { + return err + } + + if _, err = c.Writer.Write(append(data, '\n')); err != nil { + return err + } + c.Writer.Flush() + return nil + } + + client := api.NewClient(remoteURL, http.DefaultClient) + err = client.Chat(c, &req, fn) + if err != nil { + var authError api.AuthorizationError + if errors.As(err, &authError) { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) + return + } + + c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + var apiError api.StatusError + if errors.As(err, &apiError) { + c.JSON(apiError.StatusCode, apiError) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + return + } + caps := []model.Capability{model.CapabilityCompletion} if len(req.Tools) > 0 { caps = append(caps, model.CapabilityTools) @@ -1564,17 +1875,6 @@ func (s *Server) ChatHandler(c *gin.Context) { caps = append(caps, model.CapabilityThinking) } - name := model.ParseName(req.Model) - if !name.IsValid() { - c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - } - name, err := getExistingName(name) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) @@ -1603,27 +1903,23 @@ func (s *Server) ChatHandler(c *gin.Context) { } msgs = filterThinkTags(msgs, m) - var harmonyMessageHandler *HarmonyMessageHandler - var harmonyToolParser *HarmonyToolCallAccumulator - - useHarmony := shouldUseHarmony(*m) + if shouldUseHarmony(m) && m.Config.Parser == "" { + m.Config.Parser = "harmony" + } + var builtinParser parsers.Parser processedTools := req.Tools - if useHarmony { - harmonyMessageHandler = NewHarmonyMessageHandler() - var lastMessage *api.Message - if len(msgs) > 0 { - lastMessage = &msgs[len(msgs)-1] - } - harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage) - harmonyToolParser = harmonyMessageHandler.CreateToolParser() - // make a copy of tools to pass to the chat prompt. Function names may be - // renamed to be valid Harmony function names. - processedTools = make([]api.Tool, len(req.Tools)) - copy(processedTools, req.Tools) - for i, tool := range processedTools { - processedTools[i].Function.Name = harmonyMessageHandler.functionNameMap.ConvertAndAdd(tool.Function.Name) + if m.Config.Parser != "" { + builtinParser = parsers.ParserForName(m.Config.Parser) + if builtinParser != nil { + // Determine last message for chat prefill + var lastMessage *api.Message + if len(msgs) > 0 { + lastMessage = &msgs[len(msgs)-1] + } + // Initialize parser and get processed tools + processedTools = builtinParser.Init(req.Tools, lastMessage) } } @@ -1636,10 +1932,10 @@ func (s *Server) ChatHandler(c *gin.Context) { // If debug mode is enabled, return the rendered template instead of calling the model if req.DebugRenderOnly { - c.JSON(http.StatusOK, api.DebugTemplateResponse{ + c.JSON(http.StatusOK, api.ChatResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), - DebugInfo: api.DebugInfo{ + DebugInfo: &api.DebugInfo{ RenderedTemplate: prompt, ImageCount: len(images), }, @@ -1647,8 +1943,8 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - // Validate Think value: string values currently only allowed for gptoss models - if req.Think != nil && req.Think.IsString() && !useHarmony { + // Validate Think value: string values currently only allowed for harmony/gptoss models + if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) return } @@ -1660,10 +1956,14 @@ func (s *Server) ChatHandler(c *gin.Context) { OpeningTag: openingTag, ClosingTag: closingTag, } + + if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) { + thinkingState.AddContent(openingTag) + } } var toolParser *tools.Parser - if len(req.Tools) > 0 && !useHarmony { + if len(req.Tools) > 0 && (builtinParser == nil || !builtinParser.HasToolSupport()) { toolParser = tools.NewParser(m.Template.Template, req.Tools) } @@ -1695,30 +1995,24 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - if useHarmony { - content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser) - res.Message.Content = content - res.Message.Thinking = thinking - harmonyToolParser.Add(toolContent) + if builtinParser != nil { + slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content) - if r.Done { - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - *toolName = harmonyMessageHandler.functionNameMap.OriginalFromConverted(*toolName) - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) - ch <- gin.H{"error": errStr} - return - } - res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}} - } + content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done) + if err != nil { + ch <- gin.H{"error": err.Error()} + return } - // only send messages with meaningful content (empty messages confuse clients) - if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done { + res.Message.Content = content + res.Message.Thinking = thinking + res.Message.ToolCalls = toolCalls + + if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done { + slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done) ch <- res + } else { + slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser) } return diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 3b3d99100..189ef0407 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -11,6 +11,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "reflect" "slices" "strings" "testing" @@ -20,6 +21,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/types/model" ) var stream bool = false @@ -615,6 +617,78 @@ func TestCreateTemplateSystem(t *testing.T) { }) } +func TestCreateAndShowRemoteModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + var s Server + + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "test", + From: "bob", + RemoteHost: "https://ollama.com", + Info: map[string]any{ + "capabilities": []string{"completion", "tools", "thinking"}, + "model_family": "gptoss", + "context_length": 131072, + "embedding_length": 2880, + "quantization_level": "MXFP4", + "parameter_size": "20.9B", + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("exected status code 200, actual %d", w.Code) + } + + w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test"}) + if w.Code != http.StatusOK { + t.Fatalf("exected status code 200, actual %d", w.Code) + } + + var resp api.ShowResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + + expectedDetails := api.ModelDetails{ + ParentModel: "", + Format: "", + Family: "gptoss", + Families: []string{"gptoss"}, + ParameterSize: "20.9B", + QuantizationLevel: "MXFP4", + } + + if !reflect.DeepEqual(resp.Details, expectedDetails) { + t.Errorf("model details: expected %#v, actual %#v", expectedDetails, resp.Details) + } + + expectedCaps := []model.Capability{ + model.Capability("completion"), + model.Capability("tools"), + model.Capability("thinking"), + } + + if !slices.Equal(resp.Capabilities, expectedCaps) { + t.Errorf("capabilities: expected %#v, actual %#v", expectedCaps, resp.Capabilities) + } + + v, ok := resp.ModelInfo["gptoss.context_length"] + ctxlen := v.(float64) + if !ok || int(ctxlen) != 131072 { + t.Errorf("context len: expected %d, actual %d", 131072, int(ctxlen)) + } + + v, ok = resp.ModelInfo["gptoss.embedding_length"] + embedlen := v.(float64) + if !ok || int(embedlen) != 2880 { + t.Errorf("embed len: expected %d, actual %d", 2880, int(embedlen)) + } + + fmt.Printf("resp = %#v\n", resp) +} + func TestCreateLicenses(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/server/routes_debug_test.go b/server/routes_debug_test.go index f04a1da99..cc3522109 100644 --- a/server/routes_debug_test.go +++ b/server/routes_debug_test.go @@ -36,8 +36,8 @@ func TestGenerateDebugRenderOnly(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading @@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) { t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String()) } - var response api.DebugTemplateResponse + var response api.GenerateResponse if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { t.Fatalf("failed to unmarshal response: %v", err) } @@ -229,8 +229,8 @@ func TestChatDebugRenderOnly(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading @@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) { t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String()) } - var response api.DebugTemplateResponse + var response api.ChatResponse if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { t.Fatalf("failed to unmarshal response: %v", err) } diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index a57975f16..8385cb17b 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -74,8 +74,8 @@ func TestGenerateChat(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading @@ -618,8 +618,8 @@ func TestGenerate(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading @@ -969,3 +969,233 @@ func TestGenerate(t *testing.T) { } }) } + +func TestChatWithPromptEndingInThinkTag(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Helper to create a standard thinking test setup + setupThinkingTest := func(t *testing.T) (*mockRunner, *Server) { + mock := &mockRunner{ + CompletionResponse: llm.CompletionResponse{ + Done: true, + DoneReason: llm.DoneReasonStop, + PromptEvalCount: 1, + PromptEvalDuration: 1, + EvalCount: 1, + EvalDuration: 1, + }, + } + + s := &Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(mock), + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, + reschedDelay: 250 * time.Millisecond, + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { + time.Sleep(time.Millisecond) + req.successCh <- &runnerRef{llama: mock} + return false + }, + }, + } + + go s.sched.Run(t.Context()) + + // Create a model with thinking support + _, digest := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.block_count": uint32(1), + "llama.context_length": uint32(8192), + "llama.embedding_length": uint32(4096), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(8), + "tokenizer.ggml.tokens": []string{""}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []*ggml.Tensor{ + {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + }) + + // Create model with thinking template that adds at the end + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "test-thinking", + Files: map[string]string{"file.gguf": digest}, + Template: `{{- range .Messages }} +{{- if eq .Role "user" }}user: {{ .Content }} +{{ else if eq .Role "assistant" }}assistant: {{ if .Thinking }}{{ .Thinking }}{{ end }}{{ .Content }} +{{ end }}{{ end }}`, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + return mock, s + } + + mock, s := setupThinkingTest(t) + + // Helper to test chat responses + testChatRequest := func(t *testing.T, name string, userContent string, modelResponse string, expectedThinking string, expectedContent string, think bool) { + t.Run(name, func(t *testing.T) { + mock.CompletionResponse = llm.CompletionResponse{ + Content: modelResponse, + Done: true, + DoneReason: llm.DoneReasonStop, + PromptEvalCount: 1, + PromptEvalDuration: 1, + EvalCount: 1, + EvalDuration: 1, + } + mock.CompletionFn = nil + + streamRequest := false + req := api.ChatRequest{ + Model: "test-thinking", + Messages: []api.Message{ + {Role: "user", Content: userContent}, + }, + Stream: &streamRequest, + } + if think { + req.Think = &api.ThinkValue{Value: think} + } + + w := createRequest(t, s.ChatHandler, req) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + var resp api.ChatResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + + if resp.Message.Thinking != expectedThinking { + t.Errorf("expected thinking %q, got %q", expectedThinking, resp.Message.Thinking) + } + + if resp.Message.Content != expectedContent { + t.Errorf("expected content %q, got %q", expectedContent, resp.Message.Content) + } + }) + } + + // Test cases - Note: Template adds at the end, and leading whitespace after is eaten by the parser + testChatRequest(t, "basic thinking response", + "Help me solve this problem", + " Let me think about this step by step... The answer is 42.", + "Let me think about this step by step... ", + "The answer is 42.", + true) + + testChatRequest(t, "thinking with multiple sentences", + "Explain quantum computing", + " First, I need to understand the basics. Quantum bits can be in superposition. Quantum computing uses quantum mechanics principles.", + "First, I need to understand the basics. Quantum bits can be in superposition. ", + "Quantum computing uses quantum mechanics principles.", + true) + + testChatRequest(t, "no thinking content", + "What is 2+2?", + " The answer is 4.", + "", + "The answer is 4.", + true) + + testChatRequest(t, "thinking disabled but template still adds think tag", + "Simple question", + " My thoughts The answer.", + "", + " My thoughts The answer.", + false) + + // Test streaming response with template-added + t.Run("streaming with thinking", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(1) + + mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + defer wg.Done() + + // Verify the prompt ends with due to template + if !strings.HasSuffix(r.Prompt, "") { + t.Errorf("expected prompt to end with , got: %q", r.Prompt) + } + + // Simulate streaming chunks + responses := []llm.CompletionResponse{ + {Content: " I need to consider", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1}, + {Content: " multiple factors here...", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1}, + {Content: " Based on my analysis,", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1}, + {Content: " the solution is straightforward.", Done: true, DoneReason: llm.DoneReasonStop, PromptEvalCount: 1, PromptEvalDuration: 1, EvalCount: 1, EvalDuration: 1}, + } + + for _, resp := range responses { + select { + case <-ctx.Done(): + return ctx.Err() + default: + fn(resp) + time.Sleep(10 * time.Millisecond) + } + } + return nil + } + + think := true + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test-thinking", + Messages: []api.Message{{Role: "user", Content: "Analyze this complex problem"}}, + Think: &api.ThinkValue{Value: think}, + Stream: &stream, + }) + + wg.Wait() + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + // Parse streaming responses + decoder := json.NewDecoder(w.Body) + var allThinking, allContent strings.Builder + + for { + var resp api.ChatResponse + if err := decoder.Decode(&resp); err == io.EOF { + break + } else if err != nil { + t.Fatal(err) + } + allThinking.WriteString(resp.Message.Thinking) + allContent.WriteString(resp.Message.Content) + } + + // Note: Leading whitespace after is eaten by the parser + if got := allThinking.String(); got != "I need to consider multiple factors here... " { + t.Errorf("expected thinking %q, got %q", "I need to consider multiple factors here... ", got) + } + + if got := allContent.String(); got != "Based on my analysis, the solution is straightforward." { + t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got) + } + }) +} diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go index b1ede4e39..caadcb872 100644 --- a/server/routes_harmony_streaming_test.go +++ b/server/routes_harmony_streaming_test.go @@ -274,8 +274,8 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 100 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { req.successCh <- &runnerRef{ @@ -425,8 +425,8 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 100 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { req.successCh <- &runnerRef{ @@ -607,8 +607,8 @@ func TestChatHarmonyParserStreaming(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { req.successCh <- &runnerRef{ diff --git a/server/routes_test.go b/server/routes_test.go index 87b526633..bb7e2b7c1 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -126,7 +126,15 @@ func TestRoutes(t *testing.T) { t.Fatalf("failed to create model: %v", err) } - if err := createModel(r, modelName, baseLayers, fn); err != nil { + config := &ConfigV2{ + OS: "linux", + Architecture: "amd64", + RootFS: RootFS{ + Type: "layers", + }, + } + + if err := createModel(r, modelName, baseLayers, config, fn); err != nil { t.Fatal(err) } } diff --git a/server/sched.go b/server/sched.go index 927265fb5..cc3ff2440 100644 --- a/server/sched.go +++ b/server/sched.go @@ -21,6 +21,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/ml" "github.com/ollama/ollama/types/model" ) @@ -52,8 +53,8 @@ type Scheduler struct { loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool newServerFn func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) - getGpuFn func() discover.GpuInfoList - getCpuFn func() discover.GpuInfoList + getGpuFn func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList + getCpuFn func() discover.GpuInfo reschedDelay time.Duration } @@ -148,7 +149,12 @@ func (s *Scheduler) processPending(ctx context.Context) { s.loadedMu.Lock() runner := s.loaded[pending.model.ModelPath] loadedCount := len(s.loaded) + runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded)) + for _, r := range s.loaded { + runnersSnapshot = append(runnersSnapshot, r) + } s.loadedMu.Unlock() + if runner != nil { if runner.needsReload(ctx, pending) { slog.Debug("reloading", "runner", runner) @@ -166,9 +172,9 @@ func (s *Scheduler) processPending(ctx context.Context) { // Get a refreshed GPU list var gpus discover.GpuInfoList if pending.opts.NumGPU == 0 { - gpus = s.getCpuFn() + gpus = discover.GpuInfoList{s.getCpuFn()} } else { - gpus = s.getGpuFn() + gpus = s.getGpuFn(ctx, runnersSnapshot) } if envconfig.MaxRunners() <= 0 { @@ -343,7 +349,11 @@ func (s *Scheduler) processCompleted(ctx context.Context) { runner.refMu.Unlock() } else { slog.Debug("starting background wait for VRAM recovery", "runner", runner) - finished := runner.waitForVRAMRecovery() + runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded)) + for _, r := range s.loaded { + runnersSnapshot = append(runnersSnapshot, r) + } + finished := s.waitForVRAMRecovery(runner, runnersSnapshot) runner.unload() delete(s.loaded, runner.modelPath) s.loadedMu.Unlock() @@ -382,10 +392,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm // load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs // (if any). Returns whether the scheduler needs to evict a model to make this one fit. func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool { - numParallel := req.opts.NumParallel - if numParallel < 1 { - numParallel = 1 - } + numParallel := max(int(req.opts.NumParallel), 1) // Embedding models should always be loaded with parallel=1 if req.model.CheckCapabilities(model.CapabilityCompletion) != nil { @@ -432,7 +439,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis s.loadedMu.Unlock() - err := llama.Load(req.ctx, gpus, requireFull) + gpuIDs, err := llama.Load(req.ctx, gpus, requireFull) if err != nil { if errors.Is(err, llm.ErrLoadRequiredFull) { return true @@ -451,7 +458,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis llama: llama, Options: &req.opts, sessionDuration: sessionDuration, - gpus: gpus, + gpus: gpuIDs, vramSize: llama.VRAMSize(), totalSize: llama.TotalSize(), loading: true, @@ -500,11 +507,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis } func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) { - type predKey struct { - Library string - ID string - } - predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners + predMap := map[ml.DeviceID]uint64{} // Sum up the total predicted usage per GPU for all runners s.loadedMu.Lock() runners := make([]*runnerRef, 0, len(s.loaded)) for _, r := range s.loaded { @@ -515,7 +518,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) { r.refMu.Lock() if r.llama != nil { for _, gpu := range allGpus { - predMap[predKey{gpu.Library, gpu.ID}] += r.llama.VRAMByGPU(gpu.ID) + predMap[gpu.DeviceID] += r.llama.VRAMByGPU(gpu.DeviceID) } } else { slog.Warn("unexpected nil runner reference, memory prediction may be incorrect") @@ -525,7 +528,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) { // Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list for i := range allGpus { - if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok { + if p, ok := predMap[allGpus[i].DeviceID]; ok { slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory)) if p > allGpus[i].TotalMemory { // Shouldn't happen @@ -549,8 +552,8 @@ type runnerRef struct { llama llm.LlamaServer pid int - loading bool // True only during initial load, then false forever - gpus discover.GpuInfoList // Recorded at time of provisioning + loading bool // True only during initial load, then false forever + gpus []ml.DeviceID // Recorded at time of provisioning vramSize uint64 totalSize uint64 @@ -574,7 +577,6 @@ func (runner *runnerRef) unload() { runner.llama.Close() } runner.model = nil - runner.llama = nil runner.Options = nil runner.gpus = nil } @@ -621,14 +623,14 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool // a before and after GPU memory allocation. The returned channel // will be notified when we're done waiting, or have timed out and should // proceed anyway -func (runner *runnerRef) waitForVRAMRecovery() chan any { +func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.FilteredRunnerDiscovery) chan any { finished := make(chan any, 1) // CPU or Metal don't need checking, so no waiting required // windows can page VRAM, only cuda currently can report accurate used vram usage if len(runner.gpus) == 0 || - (len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "metal")) || - (runtime.GOOS == "windows" && runner.gpus[0].Library != "cuda") { + (len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "Metal")) || + (runtime.GOOS == "windows" && runner.gpus[0].Library != "CUDA") { finished <- struct{}{} slog.Debug("no need to wait for VRAM recovery", "runner", runner) return finished @@ -636,7 +638,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any { start := time.Now() // Establish a baseline before we unload - gpusBefore := discover.GetGPUInfo() + gpusBefore := s.getGpuFn(context.Background(), runners) var totalMemoryBefore, freeMemoryBefore uint64 for _, gpu := range gpusBefore { totalMemoryBefore += gpu.TotalMemory @@ -654,7 +656,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any { } // Query GPUs, look for free to go back up - gpusNow := discover.GetGPUInfo() + gpusNow := s.getGpuFn(context.Background(), runners) var totalMemoryNow, freeMemoryNow uint64 for _, gpu := range gpusNow { totalMemoryNow += gpu.TotalMemory @@ -681,8 +683,7 @@ func (runner *runnerRef) LogValue() slog.Value { } if len(runner.gpus) > 0 { attrs = append(attrs, - slog.String("inference", runner.gpus[0].Library), - slog.Int("devices", len(runner.gpus)), + slog.Any("inference", runner.gpus), ) } attrs = append(attrs, @@ -698,6 +699,32 @@ func (runner *runnerRef) LogValue() slog.Value { return slog.GroupValue(attrs...) } +// Implements discover.RunnerDiscovery +func (runner *runnerRef) GetPort() int { + if runner.llama != nil { + return runner.llama.GetPort() + } + return -1 +} + +func (runner *runnerRef) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { + if runner.llama != nil { + return runner.llama.GetDeviceInfos(ctx) + } + return nil +} + +func (runner *runnerRef) GetActiveDeviceIDs() []ml.DeviceID { + return runner.gpus +} + +func (runner *runnerRef) HasExited() bool { + if runner.llama != nil { + return runner.llama.HasExited() + } + return true +} + type ByDurationAndName []*runnerRef func (a ByDurationAndName) Len() int { return len(a) } diff --git a/server/sched_test.go b/server/sched_test.go index 0acd59118..fd6309e33 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -17,6 +17,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/ml" ) func TestMain(m *testing.M) { @@ -61,7 +62,7 @@ func TestLoad(t *testing.T) { err := <-req.errCh require.Contains(t, err.Error(), "this model may be incompatible") - server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}} + server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}} s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { server.modelPath = model return server, nil @@ -109,7 +110,7 @@ func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, f return scenario.srv, nil } -func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vramSize uint64, duration *api.Duration) *reqBundle { +func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vramSize uint64, duration *api.Duration, vramByGPU map[ml.DeviceID]uint64) *reqBundle { b := &reqBundle{} b.ctx, b.ctxDone = context.WithCancel(ctx) t.Helper() @@ -146,22 +147,24 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vra successCh: make(chan *runnerRef, 1), errCh: make(chan error, 1), } - b.srv = &mockLlm{vramSize: vramSize, vramByGPU: map[string]uint64{"": vramSize}} + b.srv = &mockLlm{vramSize: vramSize, vramByGPU: vramByGPU} return b } -func getGpuFn() discover.GpuInfoList { - g := discover.GpuInfo{Library: "metal"} +func getGpuFn(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList { + slog.Info("test getGpuFn called", "runners", runners) + g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}} g.TotalMemory = 24 * format.GigaByte g.FreeMemory = 12 * format.GigaByte return []discover.GpuInfo{g} } -func getCpuFn() discover.GpuInfoList { - g := discover.GpuInfo{Library: "cpu"} +func getCpuFn() discover.GpuInfo { + slog.Info("test getCpuFn called") + g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "cpu"}} g.TotalMemory = 32 * format.GigaByte g.FreeMemory = 26 * format.GigaByte - return []discover.GpuInfo{g} + return g } func TestRequestsSameModelSameRequest(t *testing.T) { @@ -170,8 +173,8 @@ func TestRequestsSameModelSameRequest(t *testing.T) { s := InitScheduler(ctx) s.getGpuFn = getGpuFn s.getCpuFn = getCpuFn - a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) - b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}) + a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil) + b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}, nil) b.req.model = a.req.model b.f = a.f @@ -208,13 +211,13 @@ func TestRequestsSameModelSameRequest(t *testing.T) { } func TestRequestsSimpleReloadSameModel(t *testing.T) { - ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond) + ctx, done := context.WithTimeout(t.Context(), 5000*time.Millisecond) defer done() s := InitScheduler(ctx) s.getGpuFn = getGpuFn s.getCpuFn = getCpuFn - a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) - b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}) + a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil) + b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}, nil) tmpModel := *a.req.model b.req.model = &tmpModel b.f = a.f @@ -243,6 +246,15 @@ func TestRequestsSimpleReloadSameModel(t *testing.T) { // finish first two requests, so model can reload time.Sleep(1 * time.Millisecond) a.ctxDone() + // Report recovered VRAM usage + time.Sleep(1 * time.Millisecond) + s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList { + slog.Info("XXX altered getGpuFn called") + g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}} + g.TotalMemory = 24 * format.GigaByte + g.FreeMemory = 24 * format.GigaByte + return []discover.GpuInfo{g} + } select { case resp := <-b.req.successCh: require.Equal(t, resp.llama, b.srv) @@ -259,15 +271,18 @@ func TestRequestsMultipleLoadedModels(t *testing.T) { ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond) defer done() s := InitScheduler(ctx) - s.getGpuFn = getGpuFn - s.getCpuFn = getCpuFn + s.getGpuFn = getGpuFn // 1 metal GPU + s.getCpuFn = getCpuFn // 1 CPU // Multiple loaded models - a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil) - b := newScenarioRequest(t, ctx, "ollama-model-3b", 10*format.GigaByte, nil) - c := newScenarioRequest(t, ctx, "ollama-model-4a", 10*format.GigaByte, nil) - c.req.opts.NumGPU = 0 // CPU load, will be allowed - d := newScenarioRequest(t, ctx, "ollama-model-3c", 10*format.GigaByte, nil) // Needs prior unloaded + a := newScenarioRequest(t, ctx, "model-a-1g-gpu", 1*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 1 * format.GigaByte}) + a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond} + b := newScenarioRequest(t, ctx, "model-b-10g-gpu", 10*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 10 * format.GigaByte}) + b.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond} + c := newScenarioRequest(t, ctx, "model-c-10g-cpu", 10*format.GigaByte, nil, nil /* No GPU load */) + c.req.opts.NumGPU = 0 // CPU load, will be allowed + b.req.sessionDuration = &api.Duration{Duration: 10 * time.Millisecond} // longer than b to cause the scheduler to favor unloading b over c + d := newScenarioRequest(t, ctx, "model-d-10g-gpu", 13*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 13 * format.GigaByte}) // Needs prior unloaded t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1") s.newServerFn = a.newServer @@ -338,7 +353,16 @@ func TestRequestsMultipleLoadedModels(t *testing.T) { s.loadedMu.Lock() require.Len(t, s.loaded, 2) s.loadedMu.Unlock() + // Mark b done so it can unload b.ctxDone() + // Report recovered VRAM usage so scheduler will finish waiting and unload + time.Sleep(1 * time.Millisecond) + s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList { + g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}} + g.TotalMemory = 24 * format.GigaByte + g.FreeMemory = 24 * format.GigaByte + return []discover.GpuInfo{g} + } select { case resp := <-d.req.successCh: require.Equal(t, resp.llama, d.srv) @@ -347,6 +371,19 @@ func TestRequestsMultipleLoadedModels(t *testing.T) { case <-ctx.Done(): t.Fatal("timeout") } + // Wait for b to close +closeWait: + for { + select { + case <-ctx.Done(): + t.Fatal("timeout") + default: + if b.srv.closeCalled { + break closeWait + } + time.Sleep(1 * time.Millisecond) + } + } s.loadedMu.Lock() require.Len(t, s.loaded, 2) s.loadedMu.Unlock() @@ -356,9 +393,9 @@ func TestGetRunner(t *testing.T) { ctx, done := context.WithTimeout(t.Context(), 3*time.Second) defer done() - a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond}) - b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond}) - c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}) + a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil) + b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil) + c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil) t.Setenv("OLLAMA_MAX_QUEUE", "1") s := InitScheduler(ctx) s.getGpuFn = getGpuFn @@ -420,7 +457,7 @@ func TestExpireRunner(t *testing.T) { var f *ggml.GGML gpus := discover.GpuInfoList{} - server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}} + server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}} s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { server.modelPath = model return server, nil @@ -458,10 +495,10 @@ func TestPrematureExpired(t *testing.T) { defer done() // Same model, same request - scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil) + scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil, nil) s := InitScheduler(ctx) - s.getGpuFn = func() discover.GpuInfoList { - g := discover.GpuInfo{Library: "metal"} + s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList { + g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}} g.TotalMemory = 24 * format.GigaByte g.FreeMemory = 12 * format.GigaByte return []discover.GpuInfo{g} @@ -509,7 +546,7 @@ func TestUseLoadedRunner(t *testing.T) { sessionDuration: &api.Duration{Duration: 2}, } finished := make(chan *LlmRequest) - llm1 := &mockLlm{vramByGPU: map[string]uint64{}} + llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}} r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1} req.useLoadedRunner(r1, finished) require.Equal(t, uint(1), r1.refCount) @@ -532,22 +569,32 @@ func TestUpdateFreeSpace(t *testing.T) { defer done() gpus := discover.GpuInfoList{ { - Library: "a", - ID: "1", + DeviceID: ml.DeviceID{ + ID: "1", + }, }, { - Library: "a", - ID: "2", + DeviceID: ml.DeviceID{ + ID: "2", + }, }, } gpus[0].TotalMemory = 1000 gpus[0].FreeMemory = 900 gpus[1].TotalMemory = 2000 gpus[1].FreeMemory = 1900 - llm1 := &mockLlm{vramByGPU: map[string]uint64{"1": 50, "2": 50}} - llm2 := &mockLlm{vramByGPU: map[string]uint64{"1": 125, "2": 75}} - r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1} - r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1} + gpuIDs := []ml.DeviceID{ + { + ID: "1", + }, + { + ID: "2", + }, + } + llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{{ID: "1"}: 50, {ID: "2"}: 50}} + llm2 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{{ID: "1"}: 125, {ID: "2"}: 75}} + r1 := &runnerRef{llama: llm1, gpus: gpuIDs, numParallel: 1} + r2 := &runnerRef{llama: llm2, gpus: gpuIDs, numParallel: 1} s := InitScheduler(ctx) s.loadedMu.Lock() @@ -584,7 +631,7 @@ func TestNeedsReload(t *testing.T) { ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) defer done() - llm := &mockLlm{vramByGPU: map[string]uint64{}} + llm := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}} do := api.DefaultOptions() runner := &runnerRef{ model: &Model{ @@ -631,8 +678,8 @@ func TestUnloadAllRunners(t *testing.T) { ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) defer done() - llm1 := &mockLlm{vramByGPU: map[string]uint64{}} - llm2 := &mockLlm{vramByGPU: map[string]uint64{}} + llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}} + llm2 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}} s := InitScheduler(ctx) s.unloadAllRunners() @@ -650,7 +697,7 @@ func TestUnloadAllRunners(t *testing.T) { } func TestUnload(t *testing.T) { - llm1 := &mockLlm{vramByGPU: map[string]uint64{}} + llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}} r1 := &runnerRef{llama: llm1, numParallel: 1} r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1} r1.unload() @@ -664,7 +711,7 @@ func TestAlreadyCanceled(t *testing.T) { defer done() dctx, done2 := context.WithCancel(ctx) done2() - scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0}) + scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0}, nil) s := InitScheduler(ctx) slog.Info("scenario1a") s.pendingReqCh <- scenario1a.req @@ -691,24 +738,28 @@ type mockLlm struct { closeCalled bool vramSize uint64 totalSize uint64 - vramByGPU map[string]uint64 + vramByGPU map[ml.DeviceID]uint64 } func (s *mockLlm) ModelPath() string { return s.modelPath } -func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error { +func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) { if requireFull { for _, g := range gpus { if g.FreeMemory >= s.vramSize { - return nil + return []ml.DeviceID{g.DeviceID}, nil } } - return llm.ErrLoadRequiredFull + return nil, llm.ErrLoadRequiredFull } - return nil + gpuIDs := make([]ml.DeviceID, len(gpus)) + for i := range gpus { + gpuIDs[i] = gpus[i].DeviceID + } + return gpuIDs, nil } func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp } func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp } @@ -732,7 +783,11 @@ func (s *mockLlm) Close() error { s.closeCalled = true return s.closeResp } -func (s *mockLlm) VRAMSize() uint64 { return s.vramSize } -func (s *mockLlm) TotalSize() uint64 { return s.totalSize } -func (s *mockLlm) VRAMByGPU(gpuid string) uint64 { return s.vramByGPU[gpuid] } -func (s *mockLlm) Pid() int { return -1 } +func (s *mockLlm) VRAMSize() uint64 { return s.vramSize } +func (s *mockLlm) TotalSize() uint64 { return s.totalSize } +func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] } +func (s *mockLlm) Pid() int { return -1 } +func (s *mockLlm) GetPort() int { return -1 } +func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil } +func (s *mockLlm) HasExited() bool { return false } +func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil } diff --git a/thinking/parser.go b/thinking/parser.go index a4d05e35a..bec0fb0e6 100644 --- a/thinking/parser.go +++ b/thinking/parser.go @@ -103,7 +103,9 @@ func eat(s *Parser) (string, string, bool) { // note that we use the original content, not the trimmed one because we // don't want to eat any whitespace in the real content if there were no // thinking tags - return "", s.acc.String(), false + untrimmed := s.acc.String() + s.acc.Reset() + return "", untrimmed, false } case thinkingState_ThinkingStartedEatingWhitespace: trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace) diff --git a/thinking/parser_test.go b/thinking/parser_test.go index 78c297cd9..460cf3924 100644 --- a/thinking/parser_test.go +++ b/thinking/parser_test.go @@ -58,6 +58,15 @@ func TestThinkingStreaming(t *testing.T) { wantContent: " abc", wantStateAfter: thinkingState_ThinkingDone, }, + // regression test for a bug where we were transitioning directly to + // ThinkingDone without clearing the buffer. This would cuase the first + // step to be outputted twice + { + input: "def", + wantThinking: "", + wantContent: "def", + wantStateAfter: thinkingState_ThinkingDone, + }, }, }, { diff --git a/tools/tools.go b/tools/tools.go index f9ca15530..f9a2d3b9b 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -224,22 +224,45 @@ func findArguments(buffer []byte) (map[string]any, int) { return nil, 0 } + start := -1 var braces int - var start int = -1 + var inString, escaped bool + + for i := range buffer { + c := buffer[i] + + if escaped { + escaped = false + continue + } + + if c == '\\' { + escaped = true + continue + } + + if c == '"' { + inString = !inString + continue + } + + if inString { + continue + } - for i, c := range buffer { if c == '{' { if braces == 0 { start = i } braces++ - } else if c == '}' && braces > 0 { + } else if c == '}' { braces-- if braces == 0 && start != -1 { object := buffer[start : i+1] var data map[string]any if err := json.Unmarshal(object, &data); err != nil { + // not a valid object, keep looking start = -1 continue } @@ -250,9 +273,21 @@ func findArguments(buffer []byte) (map[string]any, int) { if args, ok := obj["arguments"].(map[string]any); ok { return args, true } + if argsStr, ok := obj["arguments"].(string); ok { + var argsData map[string]interface{} + if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil { + return argsData, ok + } + } if args, ok := obj["parameters"].(map[string]any); ok { return args, true } + if argsStr, ok := obj["parameters"].(string); ok { + var argsData map[string]interface{} + if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil { + return argsData, ok + } + } return nil, true } @@ -282,6 +317,10 @@ func findArguments(buffer []byte) (map[string]any, int) { return data, i } + + if braces < 0 { + braces = 0 + } } } diff --git a/tools/tools_test.go b/tools/tools_test.go index 7f00be205..288fa73c5 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -1,6 +1,7 @@ package tools import ( + "strings" "testing" "text/template" @@ -40,13 +41,7 @@ func TestParser(t *testing.T) { Function: api.ToolFunction{ Name: "get_temperature", Description: "Retrieve the temperature for a given location", - Parameters: struct { - Type string `json:"type"` - Defs any `json:"$defs,omitempty"` - Items any `json:"items,omitempty"` - Required []string `json:"required"` - Properties map[string]api.ToolProperty `json:"properties"` - }{ + Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"city"}, Properties: map[string]api.ToolProperty{ @@ -68,13 +63,7 @@ func TestParser(t *testing.T) { Function: api.ToolFunction{ Name: "get_conditions", Description: "Retrieve the current weather conditions for a given location", - Parameters: struct { - Type string `json:"type"` - Defs any `json:"$defs,omitempty"` - Items any `json:"items,omitempty"` - Required []string `json:"required"` - Properties map[string]api.ToolProperty `json:"properties"` - }{ + Parameters: api.ToolFunctionParameters{ Type: "object", Properties: map[string]api.ToolProperty{ "location": { @@ -104,13 +93,7 @@ func TestParser(t *testing.T) { Function: api.ToolFunction{ Name: "get_address", Description: "Get the address of a given location", - Parameters: struct { - Type string `json:"type"` - Defs any `json:"$defs,omitempty"` - Items any `json:"items,omitempty"` - Required []string `json:"required"` - Properties map[string]api.ToolProperty `json:"properties"` - }{ + Parameters: api.ToolFunctionParameters{ Type: "object", Properties: map[string]api.ToolProperty{ "location": { @@ -126,13 +109,7 @@ func TestParser(t *testing.T) { Function: api.ToolFunction{ Name: "add", Description: "Add two numbers", - Parameters: struct { - Type string `json:"type"` - Defs any `json:"$defs,omitempty"` - Items any `json:"items,omitempty"` - Required []string `json:"required"` - Properties map[string]api.ToolProperty `json:"properties"` - }{ + Parameters: api.ToolFunctionParameters{ Type: "object", Properties: map[string]api.ToolProperty{ "a": { @@ -1140,11 +1117,179 @@ func TestFindArguments(t *testing.T) { }, { name: "deepseek", - buffer: []byte(`", "arguments": {"location": "Tokyo"}}`), + buffer: []byte(`"arguments": {"location": "Tokyo"}}`), want: map[string]any{ "location": "Tokyo", }, }, + { + name: "string with braces", + buffer: []byte(`{"name": "process_code", "arguments": {"code": "if (x > 0) { return true; }"}}`), + want: map[string]any{ + "code": "if (x > 0) { return true; }", + }, + }, + { + name: "string with nested json", + buffer: []byte(`{"name": "send_data", "arguments": {"payload": "{\"nested\": {\"key\": \"value\"}}"}}`), + want: map[string]any{ + "payload": `{"nested": {"key": "value"}}`, + }, + }, + { + name: "string with escaped quotes and braces", + buffer: []byte(`{"name": "analyze", "arguments": {"text": "The JSON is: {\"key\": \"val{ue}\"}"}}`), + want: map[string]any{ + "text": `The JSON is: {"key": "val{ue}"}`, + }, + }, + { + name: "multiple objects with string containing braces", + buffer: []byte(`{"name": "test", "arguments": {"query": "find } in text"}} {"name": "other"}`), + want: map[string]any{ + "query": "find } in text", + }, + }, + { + name: "unmatched closing brace in string", + buffer: []byte(`{"name": "search", "arguments": {"pattern": "regex: }"}}`), + want: map[string]any{ + "pattern": "regex: }", + }, + }, + { + name: "complex nested with mixed braces", + buffer: []byte(`{"name": "analyze", "arguments": {"data": "{\"items\": [{\"value\": \"}\"}, {\"code\": \"if (x) { return y; }\"}]}"}}`), + want: map[string]any{ + "data": `{"items": [{"value": "}"}, {"code": "if (x) { return y; }"}]}`, + }, + }, + { + name: "string with newline and braces", + buffer: []byte(`{"name": "format", "arguments": {"template": "{\n \"key\": \"value\"\n}"}}`), + want: map[string]any{ + "template": "{\n \"key\": \"value\"\n}", + }, + }, + { + name: "string with unicode escape", + buffer: []byte(`{"name": "test", "arguments": {"text": "Unicode: \u007B and \u007D"}}`), + want: map[string]any{ + "text": "Unicode: { and }", + }, + }, + { + name: "array arguments", + buffer: []byte(`{"name": "batch", "arguments": ["item1", "item2", "{\"nested\": true}"]}`), + want: nil, // This should return nil because arguments is not a map + }, + { + name: "escaped backslash before quote", + buffer: []byte(`{"name": "path", "arguments": {"dir": "C:\\Program Files\\{App}\\"}}`), + want: map[string]any{ + "dir": `C:\Program Files\{App}\`, + }, + }, + { + name: "single quotes not treated as string delimiters", + buffer: []byte(`{"name": "query", "arguments": {"sql": "SELECT * FROM users WHERE name = '{admin}'"}}`), + want: map[string]any{ + "sql": "SELECT * FROM users WHERE name = '{admin}'", + }, + }, + { + name: "incomplete json at buffer end", + buffer: []byte(`{"name": "test", "arguments": {"data": "some {"`), + want: nil, + }, + { + name: "multiple escaped quotes", + buffer: []byte(`{"name": "echo", "arguments": {"msg": "He said \"Hello {World}\" loudly"}}`), + want: map[string]any{ + "msg": `He said "Hello {World}" loudly`, + }, + }, + { + name: "json with comments style string", + buffer: []byte(`{"name": "code", "arguments": {"snippet": "// This is a comment with { and }"}}`), + want: map[string]any{ + "snippet": "// This is a comment with { and }", + }, + }, + { + name: "consecutive escaped backslashes", + buffer: []byte(`{"name": "test", "arguments": {"path": "C:\\\\{folder}\\\\"}}`), + want: map[string]any{ + "path": `C:\\{folder}\\`, + }, + }, + { + name: "empty string with braces after", + buffer: []byte(`{"name": "test", "arguments": {"a": "", "b": "{value}"}}`), + want: map[string]any{ + "a": "", + "b": "{value}", + }, + }, + { + name: "unicode in key names", + buffer: []byte(`{"name": "test", "arguments": {"key{": "value", "key}": "value2"}}`), + want: map[string]any{ + "key{": "value", + "key}": "value2", + }, + }, + { + name: "very long string with braces", + buffer: []byte(`{"name": "test", "arguments": {"data": "` + strings.Repeat("a{b}c", 100) + `"}}`), + want: map[string]any{ + "data": strings.Repeat("a{b}c", 100), + }, + }, + { + name: "tab characters and braces", + buffer: []byte(`{"name": "test", "arguments": {"code": "\tif (true) {\n\t\treturn;\n\t}"}}`), + want: map[string]any{ + "code": "\tif (true) {\n\t\treturn;\n\t}", + }, + }, + { + name: "null byte in string", + buffer: []byte(`{"name": "test", "arguments": {"data": "before\u0000{after}"}}`), + want: map[string]any{ + "data": "before\x00{after}", + }, + }, + { + name: "escaped quote at end of string", + buffer: []byte(`{"name": "test", "arguments": {"data": "text with quote at end\\\""}}`), + want: map[string]any{ + "data": `text with quote at end\"`, + }, + }, + { + name: "mixed array and object in arguments", + buffer: []byte(`{"name": "test", "arguments": {"items": ["{", "}", {"key": "value"}]}}`), + want: map[string]any{ + "items": []any{"{", "}", map[string]any{"key": "value"}}, + }, + }, + { + name: "stringified arguments", + buffer: []byte(`{"name": "get_temperature", "arguments": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "stringified parameters", + buffer: []byte(`{"name": "get_temperature", "parameters": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, } for _, tt := range tests {