Merge branch 'ollama:main' into ollama-bash-lib

This commit is contained in:
Attogram Project 2025-09-18 21:51:06 +02:00 committed by GitHub
commit 8018c9d9d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
86 changed files with 4375 additions and 825 deletions

View File

@ -65,14 +65,36 @@ jobs:
arch: amd64 arch: amd64
preset: 'CUDA 12' preset: 'CUDA 12'
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
cuda-version: '12.8' cuda-version: '12.8'
flags: '' flags: ''
runner_dir: 'cuda_v12'
- os: windows
arch: amd64
preset: 'CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
flags: ''
runner_dir: 'cuda_v13'
- os: windows - os: windows
arch: amd64 arch: amd64
preset: 'ROCm 6' preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2' rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runner_dir: ''
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release environment: release
env: env:
@ -96,7 +118,7 @@ jobs:
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
$subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | Foreach-Object {"${_}_${{ matrix.cuda-version }}"} $subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
} }
@ -138,7 +160,7 @@ jobs:
run: | run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}"
cmake --build --parallel --preset "${{ matrix.preset }}" cmake --build --parallel --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8 cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
env: env:
@ -232,7 +254,7 @@ jobs:
case "$COMPONENT" in case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;; lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;; lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;; lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;

View File

@ -46,7 +46,7 @@ jobs:
include: include:
- preset: CPU - preset: CPU
- preset: CUDA - preset: CUDA
container: nvidia/cuda:12.8.1-devel-ubuntu22.04 container: nvidia/cuda:13.0.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87' flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm - preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2 container: rocm/dev-ubuntu-22.04:6.1.2
@ -78,8 +78,17 @@ jobs:
include: include:
- preset: CPU - preset: CPU
- preset: CUDA - preset: CUDA
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
flags: '-DCMAKE_CUDA_ARCHITECTURES=80' flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
- preset: ROCm - preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
@ -102,7 +111,8 @@ jobs:
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait $subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
} }
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path

View File

@ -38,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
endif() endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama) set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama) set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR}) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
@ -81,7 +81,7 @@ if(CMAKE_CUDA_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
install(TARGETS ggml-cuda install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR} DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart PRE_INCLUDE_REGEXES cublas cublasLt cudart
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA

View File

@ -18,6 +18,14 @@
"name": "CUDA", "name": "CUDA",
"inherits": [ "Default" ] "inherits": [ "Default" ]
}, },
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
{ {
"name": "CUDA 12", "name": "CUDA 12",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
@ -26,6 +34,14 @@
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2" "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
} }
}, },
{
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 2"
}
},
{ {
"name": "JetPack 5", "name": "JetPack 5",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
@ -72,11 +88,21 @@
"configurePreset": "CUDA", "configurePreset": "CUDA",
"targets": [ "ggml-cuda" ] "targets": [ "ggml-cuda" ]
}, },
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 11"
},
{ {
"name": "CUDA 12", "name": "CUDA 12",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
"configurePreset": "CUDA 12" "configurePreset": "CUDA 12"
}, },
{
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 13"
},
{ {
"name": "JetPack 5", "name": "JetPack 5",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],

View File

@ -39,15 +39,35 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --build --parallel --preset 'CPU' \ && cmake --build --parallel --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel 8 && cmake --install build --component CPU --strip --parallel 8
FROM base AS cuda-11
ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
&& cmake --build --parallel --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-12 FROM base AS cuda-12
ARG CUDA12VERSION=12.8 ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH ENV PATH=/usr/local/cuda-12/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \ cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
&& cmake --build --parallel --preset 'CUDA 12' \ && cmake --build --parallel --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip --parallel 8 && cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-13
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
&& cmake --build --parallel --preset 'CUDA 13' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS rocm-6 FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
@ -92,10 +112,14 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama . go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64 FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64 FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5 COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6 COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6

View File

@ -414,6 +414,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.) - [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models) - [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare) - [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
### Cloud ### Cloud

View File

@ -222,7 +222,17 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
return fmt.Errorf("unmarshal: %w", err) return fmt.Errorf("unmarshal: %w", err)
} }
if response.StatusCode >= http.StatusBadRequest { if response.StatusCode == http.StatusUnauthorized {
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
return AuthorizationError{
StatusCode: response.StatusCode,
Status: response.Status,
PublicKey: pubKey,
}
} else if response.StatusCode >= http.StatusBadRequest {
return StatusError{ return StatusError{
StatusCode: response.StatusCode, StatusCode: response.StatusCode,
Status: response.Status, Status: response.Status,
@ -428,3 +438,16 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil return version.Version, nil
} }
// Signout will disconnect an ollama instance from ollama.com
func (c *Client) Signout(ctx context.Context, encodedKey string) error {
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
}
func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
var resp UserResponse
if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@ -11,6 +11,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/google/uuid"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@ -36,6 +38,19 @@ func (e StatusError) Error() string {
} }
} }
type AuthorizationError struct {
StatusCode int
Status string
PublicKey string `json:"public_key"`
}
func (e AuthorizationError) Error() string {
if e.Status != "" {
return e.Status
}
return "something went wrong, please see the ollama server logs for details"
}
// ImageData represents the raw binary data of an image file. // ImageData represents the raw binary data of an image file.
type ImageData []byte type ImageData []byte
@ -313,12 +328,28 @@ func (t *ToolFunction) String() string {
// ChatResponse is the response returned by [Client.Chat]. Its fields are // ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse]. // similar to [GenerateResponse].
type ChatResponse struct { type ChatResponse struct {
// Model is the model name that generated the response.
Model string `json:"model"` Model string `json:"model"`
// RemoteModel is the name of the upstream model that generated the response.
RemoteModel string `json:"remote_model,omitempty"`
// RemoteHost is the URL of the upstream Ollama host that generated the response.
RemoteHost string `json:"remote_host,omitempty"`
// CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
// Message contains the message or part of a message from the model.
Message Message `json:"message"` Message Message `json:"message"`
// Done specifies if the response is complete.
Done bool `json:"done"`
// DoneReason is the reason the model stopped generating text.
DoneReason string `json:"done_reason,omitempty"` DoneReason string `json:"done_reason,omitempty"`
Done bool `json:"done"` DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
Metrics Metrics
} }
@ -329,13 +360,6 @@ type DebugInfo struct {
ImageCount int `json:"image_count,omitempty"` ImageCount int `json:"image_count,omitempty"`
} }
// DebugTemplateResponse is returned when _debug_render_only is set to true
type DebugTemplateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
DebugInfo DebugInfo `json:"_debug_info"`
}
type Metrics struct { type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"` TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"`
@ -388,8 +412,12 @@ type EmbedRequest struct {
// this request. // this request.
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Truncate truncates the input to fit the model's max sequence length.
Truncate *bool `json:"truncate,omitempty"` Truncate *bool `json:"truncate,omitempty"`
// Dimensions truncates the output embedding to the specified dimension.
Dimensions int `json:"dimensions,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]any `json:"options"` Options map[string]any `json:"options"`
} }
@ -427,19 +455,48 @@ type EmbeddingResponse struct {
// CreateRequest is the request passed to [Client.Create]. // CreateRequest is the request passed to [Client.Create].
type CreateRequest struct { type CreateRequest struct {
// Model is the model name to create.
Model string `json:"model"` Model string `json:"model"`
// Stream specifies whether the response is streaming; it is true by default.
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
// Quantize is the quantization format for the model; leave blank to not change the quantization level.
Quantize string `json:"quantize,omitempty"` Quantize string `json:"quantize,omitempty"`
// From is the name of the model or file to use as the source.
From string `json:"from,omitempty"` From string `json:"from,omitempty"`
// RemoteHost is the URL of the upstream ollama API for the model (if any).
RemoteHost string `json:"remote_host,omitempty"`
// Files is a map of files include when creating the model.
Files map[string]string `json:"files,omitempty"` Files map[string]string `json:"files,omitempty"`
// Adapters is a map of LoRA adapters to include when creating the model.
Adapters map[string]string `json:"adapters,omitempty"` Adapters map[string]string `json:"adapters,omitempty"`
// Template is the template used when constructing a request to the model.
Template string `json:"template,omitempty"` Template string `json:"template,omitempty"`
// License is a string or list of strings for licenses.
License any `json:"license,omitempty"` License any `json:"license,omitempty"`
// System is the system prompt for the model.
System string `json:"system,omitempty"` System string `json:"system,omitempty"`
// Parameters is a map of hyper-parameters which are applied to the model.
Parameters map[string]any `json:"parameters,omitempty"` Parameters map[string]any `json:"parameters,omitempty"`
// Messages is a list of messages added to the model before chat and generation requests.
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
// Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"`
// Deprecated: set the model name with Model instead // Deprecated: set the model name with Model instead
Name string `json:"name"` Name string `json:"name"`
// Deprecated: use Quantize instead // Deprecated: use Quantize instead
@ -476,8 +533,12 @@ type ShowResponse struct {
Parameters string `json:"parameters,omitempty"` Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"` Template string `json:"template,omitempty"`
System string `json:"system,omitempty"` System string `json:"system,omitempty"`
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
Details ModelDetails `json:"details,omitempty"` Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"` ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"` ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"` Tensors []Tensor `json:"tensors,omitempty"`
@ -538,6 +599,8 @@ type ProcessResponse struct {
type ListModelResponse struct { type ListModelResponse struct {
Name string `json:"name"` Name string `json:"name"`
Model string `json:"model"` Model string `json:"model"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModifiedAt time.Time `json:"modified_at"` ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"` Size int64 `json:"size"`
Digest string `json:"digest"` Digest string `json:"digest"`
@ -565,6 +628,12 @@ type GenerateResponse struct {
// Model is the model name that generated the response. // Model is the model name that generated the response.
Model string `json:"model"` Model string `json:"model"`
// RemoteModel is the name of the upstream model that generated the response.
RemoteModel string `json:"remote_model,omitempty"`
// RemoteHost is the URL of the upstream Ollama host that generated the response.
RemoteHost string `json:"remote_host,omitempty"`
// CreatedAt is the timestamp of the response. // CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
@ -588,6 +657,8 @@ type GenerateResponse struct {
Metrics Metrics
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
} }
// ModelDetails provides details about a model. // ModelDetails provides details about a model.
@ -600,6 +671,18 @@ type ModelDetails struct {
QuantizationLevel string `json:"quantization_level"` QuantizationLevel string `json:"quantization_level"`
} }
// UserResponse provides information about a user.
type UserResponse struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Bio string `json:"bio,omitempty"`
AvatarURL string `json:"avatarurl,omitempty"`
FirstName string `json:"firstname,omitempty"`
LastName string `json:"lastname,omitempty"`
Plan string `json:"plan,omitempty"`
}
// Tensor describes the metadata for a given tensor. // Tensor describes the metadata for a given tensor.
type Tensor struct { type Tensor struct {
Name string `json:"name"` Name string `json:"name"`

View File

@ -19,6 +19,19 @@ import (
const defaultPrivateKey = "id_ed25519" const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) { func keyPath() (string, error) {
fileExists := func(fp string) bool {
info, err := os.Stat(fp)
if err != nil {
return false
}
return !info.IsDir()
}
systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
if fileExists(systemPath) {
return systemPath, nil
}
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
return "", err return "", err

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
@ -14,6 +15,7 @@ import (
"math" "math"
"net" "net"
"net/http" "net/http"
"net/url"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
@ -35,6 +37,7 @@ import (
"golang.org/x/term" "golang.org/x/term"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
@ -47,6 +50,8 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support // ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
if name == "" { if name == "" {
@ -56,11 +61,9 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
if err != nil { if err != nil {
return return
} }
for _, cap := range resp.Capabilities { if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
if cap == model.CapabilityThinking {
return return
} }
}
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name) fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
} }
@ -288,7 +291,17 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
Think: opts.Think, Think: opts.Think,
} }
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil }) return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
if r.RemoteModel != "" && opts.ShowConnect {
p.StopAndClear()
if strings.HasPrefix(r.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost)
}
}
return nil
})
} }
func StopHandler(cmd *cobra.Command, args []string) error { func StopHandler(cmd *cobra.Command, args []string) error {
@ -312,6 +325,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
Model: args[0], Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color", WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{}, Options: map[string]any{},
ShowConnect: true,
} }
format, err := cmd.Flags().GetString("format") format, err := cmd.Flags().GetString("format")
@ -369,6 +383,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
prompts = append([]string{string(in)}, prompts...) prompts = append([]string{string(in)}, prompts...)
opts.ShowConnect = false
opts.WordWrap = false opts.WordWrap = false
interactive = false interactive = false
} }
@ -435,6 +450,21 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if interactive { if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
// the server and the client both have the same public key
if pubKey == sErr.PublicKey {
h, _ := os.Hostname()
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
}
return nil
}
return err return err
} }
@ -455,6 +485,56 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts) return generate(cmd, opts)
} }
func SigninHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
user, err := client.Whoami(cmd.Context())
if err != nil {
return err
}
if user != nil && user.Name != "" {
fmt.Printf("You are already signed in as user '%s'\n", user.Name)
fmt.Println()
return nil
}
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
h, _ := os.Hostname()
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
return nil
}
func SignoutHandler(cmd *cobra.Command, args []string) error {
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
err = client.Signout(cmd.Context(), encKey)
if err != nil {
return err
}
fmt.Println("You have signed out of ollama.com")
fmt.Println()
return nil
}
func PushHandler(cmd *cobra.Command, args []string) error { func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
@ -507,7 +587,8 @@ func PushHandler(cmd *cobra.Command, args []string) error {
if spinner != nil { if spinner != nil {
spinner.Stop() spinner.Stop()
} }
if strings.Contains(err.Error(), "access denied") { errStr := strings.ToLower(err.Error())
if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
} }
return err return err
@ -541,7 +622,14 @@ func ListHandler(cmd *cobra.Command, args []string) error {
for _, m := range models.Models { for _, m := range models.Models {
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) { if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")}) var size string
if m.RemoteModel != "" {
size = "-"
} else {
size = format.HumanBytes(m.Size)
}
data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")})
} }
} }
@ -626,8 +714,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
KeepAlive: &api.Duration{Duration: 0}, KeepAlive: &api.Duration{Duration: 0},
} }
if err := loadOrUnloadModel(cmd, opts); err != nil { if err := loadOrUnloadModel(cmd, opts); err != nil {
if !strings.Contains(err.Error(), "not found") { if !strings.Contains(strings.ToLower(err.Error()), "not found") {
return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err) fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
} }
} }
@ -738,12 +826,36 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
} }
tableRender("Model", func() (rows [][]string) { tableRender("Model", func() (rows [][]string) {
if resp.RemoteHost != "" {
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
}
if resp.ModelInfo != nil { if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string) arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch}) rows = append(rows, []string{"", "architecture", arch})
rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))})
rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)}) var paramStr string
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)}) if resp.Details.ParameterSize != "" {
paramStr = resp.Details.ParameterSize
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
if f, ok := v.(float64); ok {
paramStr = format.HumanNumber(uint64(f))
}
}
rows = append(rows, []string{"", "parameters", paramStr})
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
} else { } else {
rows = append(rows, []string{"", "architecture", resp.Details.Family}) rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
@ -991,6 +1103,7 @@ type runOptions struct {
KeepAlive *api.Duration KeepAlive *api.Duration
Think *api.ThinkValue Think *api.ThinkValue
HideThinking bool HideThinking bool
ShowConnect bool
} }
type displayResponseState struct { type displayResponseState struct {
@ -1546,6 +1659,22 @@ func NewCLI() *cobra.Command {
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry") pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
signinCmd := &cobra.Command{
Use: "signin",
Short: "Sign in to ollama.com",
Args: cobra.ExactArgs(0),
PreRunE: checkServerHeartbeat,
RunE: SigninHandler,
}
signoutCmd := &cobra.Command{
Use: "signout",
Short: "Sign out from ollama.com",
Args: cobra.ExactArgs(0),
PreRunE: checkServerHeartbeat,
RunE: SignoutHandler,
}
listCmd := &cobra.Command{ listCmd := &cobra.Command{
Use: "list", Use: "list",
Aliases: []string{"ls"}, Aliases: []string{"ls"},
@ -1640,6 +1769,8 @@ func NewCLI() *cobra.Command {
stopCmd, stopCmd,
pullCmd, pullCmd,
pushCmd, pushCmd,
signinCmd,
signoutCmd,
listCmd, listCmd,
psCmd, psCmd,
copyCmd, copyCmd,

View File

@ -3,6 +3,7 @@ package cmd
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -304,6 +305,8 @@ func TestDeleteHandler(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} else { } else {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
errPayload := `{"error":"model '%s' not found"}`
w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
} }
return return
} }
@ -346,7 +349,7 @@ func TestDeleteHandler(t *testing.T) {
} }
err := DeleteHandler(cmd, []string{"test-model-not-found"}) err := DeleteHandler(cmd, []string{"test-model-not-found"})
if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") { if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") {
t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err) t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
} }
} }
@ -499,7 +502,7 @@ func TestPushHandler(t *testing.T) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(map[string]string{ err := json.NewEncoder(w).Encode(map[string]string{
"error": "access denied", "error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}",
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -522,6 +525,7 @@ func TestPushHandler(t *testing.T) {
defer mockServer.Close() defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL) t.Setenv("OLLAMA_HOST", mockServer.URL)
initializeKeypair()
cmd := &cobra.Command{} cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "") cmd.Flags().Bool("insecure", false, "")

View File

@ -28,6 +28,7 @@ type bertModel struct {
LayerNormEPS float32 `json:"layer_norm_eps"` LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"` LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"`
normalizeEmbeddings bool
PoolingType uint32 PoolingType uint32
} }
@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
var pooling string var pooling string
for _, m := range modules { for _, m := range modules {
if m.Type == "sentence_transformers.models.Pooling" { switch m.Type {
case "sentence_transformers.models.Pooling":
pooling = m.Path pooling = m.Path
break case "sentence_transformers.models.Normalize":
p.normalizeEmbeddings = true
} }
} }
@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv["general.architecture"] = "bert" kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false kv["bert.attention.causal"] = false
kv["bert.pooling_type"] = p.PoolingType kv["bert.pooling_type"] = p.PoolingType
kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer) kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)

View File

@ -96,7 +96,7 @@ type safetensor struct {
func (st safetensor) Kind() uint32 { func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind() kind := st.tensorBase.Kind()
if st.dtype == "BF16" && kind != tensorKindFP32 { if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
kind = tensorKindBF16 kind = tensorKindBF16
} }

View File

@ -230,3 +230,65 @@ func TestSafetensors(t *testing.T) {
}) })
} }
} }
func TestSafetensorKind(t *testing.T) {
tests := []struct {
name string
st safetensor
expected uint32
}{
{
name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "BF16",
},
expected: tensorKindBF16,
},
{
name: "BF16 dtype with v. prefix should return base kind",
st: safetensor{
tensorBase: &tensorBase{
name: "v.weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "BF16",
},
expected: tensorKindFP16,
},
{
name: "BF16 dtype with FP32 base kind should return FP32",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10}, // will default to FP32
},
dtype: "BF16",
},
expected: tensorKindFP32,
},
{
name: "Non-BF16 dtype should return base kind",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "FP16",
},
expected: tensorKindFP16,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.st.Kind()
if result != tt.expected {
t.Errorf("Kind() = %d, expected %d", result, tt.expected)
}
})
}
}

View File

@ -16,7 +16,7 @@ import (
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK") var CudaTegra string = os.Getenv("JETSON_JETPACK")
func cudaVariant(gpuInfo CudaGPUInfo) string { func cudaVariant(gpuInfos []CudaGPUInfo) string {
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" { if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
if CudaTegra != "" { if CudaTegra != "" {
ver := strings.Split(CudaTegra, ".") ver := strings.Split(CudaTegra, ".")
@ -43,14 +43,22 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
} }
} }
} }
return "sbsa"
} }
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers // Check GPU compute capability FIRST, lowest common denominator if multi-gpu
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) { for _, gpuInfo := range gpuInfos {
// The detected driver is older than Feb 2023 if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) {
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor)) // GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1)
return "v11"
}
return "v12" return "v12"
} }
}
// GPU is Turing or newer (CC >= 7.5) - can use newer CUDA
if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 {
// The detected driver is older than 580 (Aug 2025)
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor))
return "v12"
}
return "v13"
}

View File

@ -284,18 +284,8 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.MinimumMemory = cudaMinimumMemory gpuInfo.MinimumMemory = cudaMinimumMemory
gpuInfo.DriverMajor = driverMajor gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor gpuInfo.DriverMinor = driverMinor
variant := cudaVariant(gpuInfo)
// Start with our bundled libraries
if variant != "" {
variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
if _, err := os.Stat(variantPath); err == nil {
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
}
}
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
gpuInfo.Variant = variant
if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) { if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) {
unsupportedGPUs = append(unsupportedGPUs, unsupportedGPUs = append(unsupportedGPUs,
@ -333,6 +323,24 @@ func GetGPUInfo() GpuInfoList {
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does... // TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo) cudaGPUs = append(cudaGPUs, gpuInfo)
} }
// Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths
variant := cudaVariant(cudaGPUs)
var variantPath string
// Start with our bundled libraries
if variant != "" {
variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant)
if _, err := os.Stat(variantPath); err != nil {
variantPath = ""
}
}
for i := range cudaGPUs {
cudaGPUs[i].Variant = variant
if variantPath != "" {
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...)
}
}
} }
// Intel // Intel

View File

@ -1708,6 +1708,7 @@ Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true` - `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `dimensions`: number of dimensions for the embedding
### Examples ### Examples

View File

@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
go run . serve go run . serve
``` ```
> [!NOTE]
> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
## macOS (Apple Silicon) ## macOS (Apple Silicon)
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required. macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.

View File

@ -11,12 +11,13 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install ## Manual install
> [!NOTE] > [!NOTE]
> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first. > If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
Download and extract the package: Download and extract the package:
```shell ```shell
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
sudo rm -rf /usr/lib/ollama
sudo tar -C /usr -xzf ollama-linux-amd64.tgz sudo tar -C /usr -xzf ollama-linux-amd64.tgz
``` ```

View File

@ -134,6 +134,17 @@ func LoadTimeout() (loadTimeout time.Duration) {
return loadTimeout return loadTimeout
} }
func Remotes() []string {
var r []string
raw := strings.TrimSpace(Var("OLLAMA_REMOTES"))
if raw == "" {
r = []string{"ollama.com"}
} else {
r = strings.Split(raw, ",")
}
return r
}
func Bool(k string) func() bool { func Bool(k string) func() bool {
return func() bool { return func() bool {
if s := Var(k); s != "" { if s := Var(k); s != "" {
@ -185,8 +196,6 @@ var (
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096) ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server // Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH") UseAuth = Bool("OLLAMA_AUTH")
// Enable the new memory estimation logic
NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES")
) )
func String(s string) func() string { func String(s string) func() string {
@ -272,7 +281,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_NEW_ESTIMATES": {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"}, "OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
// Informational // Informational
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},

View File

@ -243,6 +243,7 @@ func (kv KV) OllamaEngineRequired() bool {
"gemma3", "gemma3",
"gemma3n", "gemma3n",
"mistral3", "mistral3",
"qwen3",
"llama4", "llama4",
"mllama", "mllama",
"qwen25vl", "qwen25vl",
@ -864,12 +865,16 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
// SupportsKVCacheType checks if the requested cache type is supported // SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool { func (f GGML) SupportsKVCacheType(cacheType string) bool {
if cacheType == "" || cacheType == "f16" {
return true
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) { if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
// gpt-oss uses attention with sinks which does not support quantized cache types // gpt-oss uses attention with sinks which does not support quantized cache types
slog.Warn("model only supports non-quantized cache types ", "mode", arch) slog.Warn("model only supports non-quantized cache types", "model", arch)
return cacheType == "f16" return false
} }
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType) return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
} }
// SupportsFlashAttention checks if the model supports flash attention // SupportsFlashAttention checks if the model supports flash attention
@ -879,6 +884,10 @@ func (f GGML) SupportsFlashAttention() bool {
return false return false
} }
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
return false
}
// Check head counts match and are non-zero // Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK() headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV() headCountV := f.KV().EmbeddingHeadCountV()

View File

@ -3,29 +3,15 @@ package harmony
import ( import (
"fmt" "fmt"
"log/slog" "log/slog"
"slices"
"strings" "strings"
"unicode" "unicode"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/template"
) )
type harmonyParserState int type harmonyParserState int
func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if template.Contains("<|start|>") && template.Contains("<|end|>") {
return true
}
}
return false
}
const ( const (
harmonyParserState_LookingForMessageStart harmonyParserState = iota harmonyParserState_LookingForMessageStart harmonyParserState = iota
harmonyParserState_ParsingHeader harmonyParserState_ParsingHeader
@ -89,29 +75,19 @@ func (s *HarmonyParser) AddImplicitStart() {
s.acc.WriteString("<|start|>assistant") s.acc.WriteString("<|start|>assistant")
} }
func Prefill(lastMessage api.Message) string { func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
if lastMessage.Role != "assistant" { if lastMessage != nil && lastMessage.Role == "assistant" {
return "" // handle prefilling conditions
} if lastMessage.Content != "" {
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
switch { return
case strings.TrimSpace(lastMessage.Content) != "": } else if lastMessage.Thinking != "" {
return "<|start|>assistant<|channel|>final<|message|>" s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
case strings.TrimSpace(lastMessage.Thinking) != "": return
return "<|start|>assistant<|channel|>analysis<|message|>"
default:
return ""
} }
} }
// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
if strings.TrimSpace(prefillString) != "" {
s.acc.WriteString(prefillString)
} else {
s.AddImplicitStart() s.AddImplicitStart()
} }
}
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent { func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
s.lifetimeAcc.WriteString(content) s.lifetimeAcc.WriteString(content)

View File

@ -3,7 +3,6 @@ package harmony
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"testing" "testing"
) )
@ -536,202 +535,3 @@ func TestFunctionConvertAndAdd(t *testing.T) {
}) })
} }
} }
func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("thinking_then_content_streams", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.CreateToolParser()
type step struct {
in string
wantContent string
wantThinking string
}
steps := []step{
{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
{in: "<|end|>", wantThinking: ""},
{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
{in: "<|end|>", wantContent: ""},
}
for i, s := range steps {
content, thinking, tool := handler.AddContent(s.in, tp)
if tool != "" {
tp.Add(tool)
}
if content != s.wantContent || thinking != s.wantThinking {
t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
}
}
})
t.Run("content_streams_as_it_arrives", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.CreateToolParser()
inputs := []string{
"<|start|>assistant<|message|>Hello",
", world",
"!<|end|>",
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in, tp)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
t.Fatalf("unexpected thinking %q", thinking)
}
if content != "" {
got = append(got, content)
}
}
want := []string{"Hello", ", world", "!"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
}
})
t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.CreateToolParser()
inputs := []string{
"<|channel|>analysis<|message|>Thinking...",
"<|end|>",
"<|start|>assistant<|message|>Answer",
"<|end|>",
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in, tp)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
got = append(got, thinking)
}
if content != "" {
got = append(got, content)
}
}
want := []string{"Thinking...", "Answer"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
}
})
t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.CreateToolParser()
inputs := []string{
"<|chan",
"nel|>analysis<|mess",
"age|>Deep ",
"thought",
"<|end|>",
"<|start|>assistant<|message|>Done",
"<|end|>",
}
var thinkingPieces []string
var contentPieces []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in, tp)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
thinkingPieces = append(thinkingPieces, thinking)
}
if content != "" {
contentPieces = append(contentPieces, content)
}
}
if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
}
if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
}
})
t.Run("simple_assistant_after_analysis", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.CreateToolParser()
inputs := []string{
"<|channel|>analysis<|message|>Think",
"<|end|>",
"<|start|>assistant<|message|>Answer",
"<|end|>",
}
var contentSb, thinkingSb strings.Builder
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in, tp)
if tool != "" {
tp.Add(tool)
}
contentSb.WriteString(content)
thinkingSb.WriteString(thinking)
}
if contentSb.String() != "Answer" {
t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
}
if thinkingSb.String() != "Think" {
t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
}
})
t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.CreateToolParser()
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in, tp)
if content != "" || thinking != "" {
continue
}
if tool != "" {
tp.Add(tool)
}
}
name, args := tp.Drain()
if name == nil || *name != "functions.calculate" {
t.Fatalf("unexpected tool name: %v", name)
}
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
t.Fatalf("unexpected tool args: got %s want %s", got, want)
}
})
t.Run("tool_call_across_chunks", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.CreateToolParser()
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
"2\"}",
"<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in, tp)
if content != "" || thinking != "" {
continue
}
if tool != "" {
tp.Add(tool)
}
}
name, args := tp.Drain()
if name == nil || *name != "functions.calculate" {
t.Fatalf("unexpected tool name: %v", name)
}
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
t.Fatalf("unexpected tool args: got %s want %s", got, want)
}
})
}

View File

@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: smol,
Prompt: "Write me a story with a ton of emojis?", Prompt: "Write me a story in english with a lot of emojis",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
"temperature": 0, "temperature": 0,

View File

@ -561,7 +561,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
}, { }, {
Model: smol, Model: smol,
Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply", Prompt: "how do rainbows form? Be brief but factual in your reply",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
}, { }, {
@ -579,9 +579,9 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[][]string{ [][]string{
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"}, {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"}, {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"}, {"water", "droplet", "refracted", "reflect", "color", "spectrum"},
{"fourth", "july", "declaration", "independence"}, {"fourth", "july", "declaration", "independence"},
{"nitrogen", "oxygen", "carbon", "dioxide"}, {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
} }
} }

View File

@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
} }
nChunks := C.mtmd_input_chunks_size(ic) nChunks := C.mtmd_input_chunks_size(ic)
numEmbed := llamaContext.Model().NEmbd() numEmbed := llamaContext.Model().NEmbd()
lastChunkSize := 0 embed := make([][]float32, 0)
for i := range int(nChunks) { for i := range int(nChunks) {
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i)) chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk)) numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
lastChunkSize = numTokens slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)
// Encode the chunk // Encode the chunk
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
return nil, errors.New("unable to encode mtmd image chunk") return nil, errors.New("unable to encode mtmd image chunk")
} }
}
// Get the embeddings // Get the embeddings for this chunk
embed := make([][]float32, lastChunkSize) chunkEmbed := make([][]float32, numTokens)
embd := C.mtmd_get_output_embd(c.c) chunkEmbd := C.mtmd_get_output_embd(c.c)
if nil == embd { if nil == chunkEmbd {
return nil, errors.New("failed to get image embedding") continue
} }
// Extend the embedding array for each token // Extend the embedding array for each token
s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize) s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
rows := make([]float32, len(s)) rows := make([]float32, len(s))
copy(rows, s) copy(rows, s)
for i := range lastChunkSize { for i := range numTokens {
embed[i] = rows[i*numEmbed : (i+1)*numEmbed] chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
} }
embed = append(embed, chunkEmbed...)
}
slog.Debug("image embeddings", "totalEmbeddings", len(embed))
return embed, nil return embed, nil
} }

View File

@ -202,7 +202,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
var kvct string var kvct string
if useFlashAttention { if useFlashAttention {
requested := strings.ToLower(envconfig.KvCacheType()) requested := strings.ToLower(envconfig.KvCacheType())
if requested != "" && f.SupportsKVCacheType(requested) { if f.SupportsKVCacheType(requested) {
kvct = requested kvct = requested
} }
} }

View File

@ -148,7 +148,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
var textProcessor model.TextProcessor var textProcessor model.TextProcessor
var err error var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() { if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath) textProcessor, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
if err != nil { if err != nil {
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner // To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err) slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
@ -161,11 +165,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
} }
} }
newEstimates := textProcessor != nil && envconfig.NewMemoryEstimates()
if newEstimates {
slog.Info("enabling new memory estimates")
}
// Verify the requested context size is <= the model training size // Verify the requested context size is <= the model training size
trainCtx := f.KV().ContextLength() trainCtx := f.KV().ContextLength()
if opts.NumCtx > int(trainCtx) && trainCtx > 0 { if opts.NumCtx > int(trainCtx) && trainCtx > 0 {
@ -220,7 +219,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
// Flash Attention also supports kv cache quantization // Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model // Enable if the requested and kv cache type is supported by the model
if kvct != "" && f.SupportsKVCacheType(kvct) { if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct loadRequest.KvCacheType = kvct
} else { } else {
slog.Warn("kv cache type not supported by model", "type", kvct) slog.Warn("kv cache type not supported by model", "type", kvct)
@ -433,7 +432,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
} }
}() }()
if newEstimates { if textProcessor != nil {
return &ollamaServer{llmServer: s}, nil return &ollamaServer{llmServer: s}, nil
} else { } else {
return &llamaServer{llmServer: s, ggml: f}, nil return &llamaServer{llmServer: s, ggml: f}, nil
@ -1350,8 +1349,6 @@ type CompletionRequest struct {
Options *api.Options Options *api.Options
Grammar string // set before sending the request to the subprocess Grammar string // set before sending the request to the subprocess
UseHarmony bool
PrefillString string
} }
// DoneReason represents the reason why a completion response is done // DoneReason represents the reason why a completion response is done
@ -1364,8 +1361,6 @@ const (
DoneReasonLength DoneReasonLength
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed // DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
DoneReasonConnectionClosed DoneReasonConnectionClosed
// DoneReasonTokenRepeatLimit indicates the completion stopped due to a token repeat limit
DoneReasonTokenRepeatLimit
) )
func (d DoneReason) String() string { func (d DoneReason) String() string {
@ -1374,8 +1369,6 @@ func (d DoneReason) String() string {
return "length" return "length"
case DoneReasonStop: case DoneReasonStop:
return "stop" return "stop"
case DoneReasonTokenRepeatLimit:
return "token_repeat_limit"
default: default:
return "" // closed return "" // closed
} }
@ -1383,8 +1376,6 @@ func (d DoneReason) String() string {
type CompletionResponse struct { type CompletionResponse struct {
Content string `json:"content"` Content string `json:"content"`
Thinking string `json:"thinking"`
ToolCalls []api.ToolCall `json:"tool_calls"`
DoneReason DoneReason `json:"done_reason"` DoneReason DoneReason `json:"done_reason"`
Done bool `json:"done"` Done bool `json:"done"`
PromptEvalCount int `json:"prompt_eval_count"` PromptEvalCount int `json:"prompt_eval_count"`
@ -1508,8 +1499,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("error unmarshalling llm prediction response: %v", err) return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
} }
switch { switch {
// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future case strings.TrimSpace(c.Content) == lastToken:
case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
tokenRepeat++ tokenRepeat++
default: default:
lastToken = strings.TrimSpace(c.Content) lastToken = strings.TrimSpace(c.Content)
@ -1522,14 +1512,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return ctx.Err() return ctx.Err()
} }
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
})
}
if c.Done { if c.Done {
fn(c) fn(c)
return nil return nil
} }
if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
fn(c)
}
} }
} }

View File

@ -5,6 +5,8 @@ import (
"io" "io"
"log/slog" "log/slog"
"path/filepath" "path/filepath"
"runtime"
"time"
) )
const LevelTrace slog.Level = -8 const LevelTrace slog.Level = -8
@ -29,10 +31,18 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
})) }))
} }
type key string
func Trace(msg string, args ...any) { func Trace(msg string, args ...any) {
slog.Log(context.TODO(), LevelTrace, msg, args...) TraceContext(context.WithValue(context.TODO(), key("skip"), 1), msg, args...)
} }
func TraceContext(ctx context.Context, msg string, args ...any) { func TraceContext(ctx context.Context, msg string, args ...any) {
slog.Log(ctx, LevelTrace, msg, args...) if logger := slog.Default(); logger.Enabled(ctx, LevelTrace) {
skip, _ := ctx.Value(key("skip")).(int)
pc, _, _, _ := runtime.Caller(1 + skip)
record := slog.NewRecord(time.Now(), LevelTrace, msg, pc)
record.Add(args...)
logger.Handler().Handle(ctx, record)
}
} }

View File

@ -416,6 +416,7 @@ type Tensor interface {
AddID(ctx Context, t2, ids Tensor) Tensor AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor Softmax(ctx Context) Tensor
L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor Scale(ctx Context, s float64) Tensor
@ -429,12 +430,13 @@ type Tensor interface {
Sin(ctx Context) Tensor Sin(ctx Context) Tensor
Cos(ctx Context) Tensor Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor GELU(ctx Context, up ...Tensor) Tensor
QuickGELU(ctx Context) Tensor SILU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context) Tensor RELU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor Sigmoid(ctx Context) Tensor
SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
Reshape(ctx Context, shape ...int) Tensor Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor View(ctx Context, offset int, shape ...int) Tensor

View File

@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
} }
} }
func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)) tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if w != nil { if w != nil {
@ -1424,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
} }
} }
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
b: t.b,
t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
}
return &Tensor{ return &Tensor{
b: t.b, b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t), t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
} }
} }
func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor { func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{ return &Tensor{
b: t.b, b: t.b,
t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t), t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
} }
} }
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{ return &Tensor{
b: t.b, b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t), t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
} }
} }
func (t *Tensor) RELU(ctx ml.Context) ml.Tensor { func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
b: t.b,
t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
}
return &Tensor{ return &Tensor{
b: t.b, b: t.b,
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t), t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
} }
} }
func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor { func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
return &Tensor{ return &Tensor{
b: t.b, b: t.b,
t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)), t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),

View File

@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
} }
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
if key != nil && value != nil { if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) { if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
} }
ctx.Forward(key, value)
if cache != nil { if cache != nil {
cache.Put(ctx, key, value) cache.Put(ctx, key, value)
} }

42
ml/nn/pooling/pooling.go Normal file
View File

@ -0,0 +1,42 @@
package pooling
import (
"github.com/ollama/ollama/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
)
func (t Type) String() string {
switch t {
case TypeMean:
return "Mean"
case TypeCLS:
return "CLS"
case TypeLast:
return "Last"
default:
return "Unknown"
}
}
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
switch t {
case TypeMean:
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
case TypeCLS:
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
case TypeLast:
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
return hiddenStates
default:
panic("unknown pooling type")
}
}

View File

@ -0,0 +1,79 @@
package pooling_test
import (
"bytes"
"os"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/discover"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn/pooling"
)
func setup(tb testing.TB, n int) ml.Backend {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "*.bin")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
if err := fsggml.WriteGGUF(f, fsggml.KV{
"general.architecture": "test",
"test.block_count": uint32(1),
}, []*fsggml.Tensor{
{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
}); err != nil {
tb.Fatal(err)
}
var gpuLayers ml.GPULayersList
if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
gpuLayers = append(gpuLayers, ml.GPULayers{
ID: gpus[0].ID,
Layers: slices.Collect(func(yield func(int) bool) {
for i := range n {
if !yield(i) {
return
}
}
}),
})
}
b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
if err != nil {
tb.Fatal(err)
}
return b
}
func TestForward(t *testing.T) {
cases := map[pooling.Type][]float32{
pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7},
pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
}
for typ, want := range cases {
t.Run(typ.String(), func(t *testing.T) {
b := setup(t, 99)
defer b.Close()
ctx := b.NewContext()
defer ctx.Close()
tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
tt = typ.Forward(ctx, tt)
ctx.Forward(tt).Compute(tt)
if diff := cmp.Diff(want, tt.Floats()); diff != "" {
t.Error(diff)
}
})
}
}

View File

@ -54,10 +54,9 @@ type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs. // Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor Inputs ml.Tensor
// Multimodal is a set of multimodal embeddings previously created by // Outputs are the set of indicies into Inputs for which output data should
// EncodeMultimodal, along with an index into Inputs. Unused for text-only // be returned.
// models or for batches without multimodal elements. Outputs ml.Tensor
Multimodal []MultimodalIndex
// Positions is the position for each Input, relative to its sequence. Equal // Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs. // in length to Inputs.
@ -66,7 +65,8 @@ type Batch struct {
// Sequences is the sequence for each Input. Equal in length to Inputs. // Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int Sequences []int
// Outputs are the set of indicies into Inputs for which output data should // Multimodal is a set of multimodal embeddings previously created by
// be returned. // EncodeMultimodal, along with an index into Inputs. Unused for text-only
Outputs []int32 // models or for batches without multimodal elements.
Multimodal []MultimodalIndex
} }

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
_ "image/jpeg" _ "image/jpeg"
_ "image/png" _ "image/png"
"math"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
@ -21,10 +20,15 @@ import (
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend" _ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
var ErrNoVisionModel = errors.New("this model is missing data required for image input") var (
ErrNoVisionModel = errors.New("this model is missing data required for image input")
ErrUnsupportedModel = errors.New("model not supported")
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface { type Model interface {
@ -104,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
} }
arch := b.Config().Architecture() arch := b.Config().Architecture()
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 { if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed" arch = arch + "_embed"
} }
@ -242,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
vv = vv.Elem() vv = vv.Elem()
} }
vv = vv.Elem() vv = reflect.Indirect(vv)
if v.IsNil() { if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem() vv = reflect.New(v.Type().Elem()).Elem()
} }

181
model/models/bert/embed.go Normal file
View File

@ -0,0 +1,181 @@
package bert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
Layers []EncoderLayer `gguf:"blk"`
Options
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
}
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
if m.normalize {
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
}
return hiddenStates, nil
}
type EncoderLayer struct {
*Attention
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
*MLP
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
}
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
// Attention
residual := hiddenStates
hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// MLP
residual = hiddenStates
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
query := a.Query.Forward(ctx, hiddenStates)
if a.QueryNorm != nil {
query = a.QueryNorm.Forward(ctx, query, opts.eps)
}
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
key := a.Key.Forward(ctx, hiddenStates)
if a.KeyNorm != nil {
key = a.KeyNorm.Forward(ctx, key, opts.eps)
}
key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
value := a.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return a.Output.Forward(ctx, attention)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
}
type Options struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength int
poolingType pooling.Type
eps float32
normalize bool
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func New(c fs.Config) (model.Model, error) {
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
//nolint:misspell
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
// support it for compatibility.
c.Uint("tokenizer.ggml.seperator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_epsilon"),
poolingType: pooling.Type(c.Uint("pooling_type")),
normalize: c.Bool("normalize_embeddings", true),
},
}, nil
}
func init() {
model.Register("bert", New)
model.Register("bert_embed", New)
}

View File

@ -24,7 +24,7 @@ type Options struct {
type Model struct { type Model struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
TokenEmbedding *nn.Embedding `gguf:"token_embd"` TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"` Layers []Layer `gguf:"blk"`
@ -40,7 +40,7 @@ const (
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),
@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
attnValLen: int(c.Uint("attention.value_length")), attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0), ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.freq_scale", 1.0), ropeScale: c.Float("rope.scaling.factor", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"), attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"), finalLogitSoftcap: c.Float("final_logit_softcapping"),
}, },
@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState) q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling { if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@ -138,7 +138,7 @@ type MLP struct {
} }
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }
@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var lastLayerOutputs ml.Tensor var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
lastLayerOutputs = outputs lastLayerOutputs = batch.Outputs
} }
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)

View File

@ -1,49 +1,38 @@
package gemma3 package gemma3
import ( import (
"errors"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
type embedModel struct { type embedModel struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
*TextModel *TextModel
PoolingType uint32 poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"` Dense [2]*nn.Linear `gguf:"dense"`
} }
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
batch.Outputs = batch.Positions // return all positions
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
switch m.PoolingType {
case 0: // None
case 1: // Mean
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
default:
return nil, errors.New("unsupported pooling type")
}
for _, dense := range m.Dense { for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates) hiddenStates = dense.Forward(ctx, hiddenStates)
} }
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil return hiddenStates, nil
} }
func newEmbedModel(c fs.Config) (model.Model, error) { func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{ m := &embedModel{
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),
@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
}, },
), ),
TextModel: newTextModel(c), TextModel: newTextModel(c),
PoolingType: c.Uint("pooling_type", 0), poolingType: pooling.Type(c.Uint("pooling_type", 0)),
} }
m.Cache = kvcache.NewWrapperCache( m.Cache = kvcache.NewWrapperCache(

View File

@ -16,7 +16,7 @@ import (
type Model struct { type Model struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
*VisionModel `gguf:"v"` *VisionModel `gguf:"v"`
*TextModel *TextModel
@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),

View File

@ -53,7 +53,7 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: c.Float("rope.freq_scale", 1.0), ropeScale: c.Float("rope.scaling.factor", 1.0),
}, },
} }
@ -84,7 +84,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
q := sa.Query.Forward(ctx, hiddenState) q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps) q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling { if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@ -95,7 +95,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps) k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@ -123,7 +123,7 @@ type TextMLP struct {
} }
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor { func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }
@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
var lastLayerOutputs ml.Tensor var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
lastLayerOutputs = outputs lastLayerOutputs = batch.Outputs
} }
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig) hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)

View File

@ -10,7 +10,7 @@ import (
type Model struct { type Model struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
*TextModel *TextModel
} }
@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
TextModel: newTextModel(c), TextModel: newTextModel(c),
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),

View File

@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx) hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx) hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))) hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil return m.Output.Forward(ctx, hiddenStates), nil
@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.ropeBaseLocal ropeBase = m.ropeBaseLocal
} }
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
} }
type TextScaledWordEmbedding struct { type TextScaledWordEmbedding struct {
@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
} }
active = d.PerLayerInputGate.Forward(ctx, active) active = d.PerLayerInputGate.Forward(ctx, active)
active = active.GELU(ctx) active = active.GELU(ctx, perLayerInput)
active = active.Mul(ctx, perLayerInput)
active = d.PerLayerProjection.Forward(ctx, active) active = d.PerLayerProjection.Forward(ctx, active)
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps) active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
@ -257,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
query := attn.Query.Forward(ctx, hiddenStates) query := attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
query = attn.QueryNorm.Forward(ctx, query, opts.eps) query = attn.QueryNorm.Forward(ctx, query, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX()) query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
var key, value ml.Tensor var key, value ml.Tensor
if !sharedKV { if !sharedKV {
key = attn.Key.Forward(ctx, hiddenStates) key = attn.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
key = attn.KeyNorm.Forward(ctx, key, opts.eps) key = attn.KeyNorm.Forward(ctx, key, opts.eps)
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX()) key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
value = attn.Value.Forward(ctx, hiddenStates) value = attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx) hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
} }
hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates) hiddenStates = hiddenStates.GELU(ctx, upStates)
hiddenStates = mlp.Down.Forward(ctx, hiddenStates) hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
return hiddenStates return hiddenStates
} }
@ -350,7 +349,7 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeBase: c.Float("rope.freq_base", 1_000_000), ropeBase: c.Float("rope.freq_base", 1_000_000),
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000), ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
ropeScale: c.Float("rope.freq_scale", 1.0), ropeScale: c.Float("rope.scaling.factor", 1.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"), slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
activationSparsityScale: c.Floats("activation_sparsity_scale"), activationSparsityScale: c.Floats("activation_sparsity_scale"),

View File

@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
} }
var outputs ml.Tensor var outputs ml.Tensor
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 { if i == len(m.TransformerBlocks)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = batch.Outputs
} }
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options) hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts) up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
} }
hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7) hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7)
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights) experts = experts.Mul(ctx, routingWeights)

View File

@ -2,7 +2,6 @@ package llama
import ( import (
"cmp" "cmp"
"fmt"
"math" "math"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
@ -23,30 +22,26 @@ type Options struct {
type Model struct { type Model struct {
model.Base model.Base
model.BytePairEncoding model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"` TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"` Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"` OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"` Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options Options
} }
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
// This model currently only supports the gpt2 tokenizer if c.Uint("expert_count") > 0 {
if c.String("tokenizer.ggml.model") == "llama" { // TODO: support mixtures of experts
return nil, fmt.Errorf("unsupported tokenizer: llama") return nil, model.ErrUnsupportedModel
} }
// Best effort detection of library/deepseek-coder model(s) which are incompatible
if c.String("general.name") == "deepseek-ai" { var processor model.TextProcessor
return nil, fmt.Errorf("unsupported model: %s", c.String("general.name")) vocabulary := model.Vocabulary{
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"), Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
@ -56,18 +51,31 @@ func New(c fs.Config) (model.Model, error) {
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, }
), switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
&vocabulary,
)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")), Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{ Options: Options{
hiddenSize: int(c.Uint("embedding_length")), hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")), numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")), numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")), headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base", 1e5),
ropeScale: c.Float("rope.freq_scale", 1), ropeScale: c.Float("rope.scaling.factor", 1),
}, },
} }
@ -98,8 +106,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
value := sa.Value.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@ -108,7 +116,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
} }
type MLP struct { type MLP struct {
@ -118,7 +126,7 @@ type MLP struct {
} }
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }
@ -160,10 +168,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = batch.Outputs
} }
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options)
} }
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

View File

@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
} }
func init() { func init() {

View File

@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if useRope { if useRope {
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
} }
if opts.useQKNorm { if opts.useQKNorm {
@ -58,14 +58,14 @@ type TextMLP struct {
} }
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }
type TextExperts struct { type TextExperts struct {
Gate *nn.Linear `gguf:"ffn_gate_exps"` Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"` Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"` Down *nn.LinearBatch `gguf:"ffn_down_exps"`
} }
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores) hiddenStates = hiddenStates.Mul(ctx, scores)
upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts) upStates := e.Up.Forward(ctx, hiddenStates, experts)
gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts) gateStates := e.Gate.Forward(ctx, hiddenStates, experts)
downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ { for i := 1; i < opts.numExpertsUsed; i++ {
@ -96,7 +96,7 @@ type TextSharedExpert struct {
} }
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }
@ -196,7 +196,7 @@ func newTextModel(c fs.Config) *TextModel {
numExpertsUsed: int(c.Uint("expert_used_count")), numExpertsUsed: int(c.Uint("expert_used_count")),
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1), ropeScale: c.Float("rope.scaling.factor", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)), interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
noRopeInterval: int(c.Uint("no_rope_interval", 4)), noRopeInterval: int(c.Uint("no_rope_interval", 4)),
@ -248,5 +248,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
} }
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
} }

View File

@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
} }
func init() { func init() {

View File

@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState) q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
} }
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
} }
type MLP struct { type MLP struct {
@ -65,7 +65,7 @@ type MLP struct {
} }
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }
@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1), ropeScale: c.Float("rope.scaling.factor", 1),
}, },
} }
} }

View File

@ -51,7 +51,7 @@ type VisionMLP struct {
} }
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }

View File

@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
} }
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
// TODO: attention mask, cross attention mask // TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
} }
func init() { func init() {

View File

@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
query := sa.Query.Forward(ctx, hiddenState) query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key := sa.Key.Forward(ctx, hiddenState) key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
value := sa.Value.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers // This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
} }
return key, nil return key, nil
@ -58,7 +58,7 @@ type TextMLP struct {
} }
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor { func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }
@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1), ropeScale: c.Float("rope.scaling.factor", 1),
crossAttentionLayers: c.Ints("attention.cross_attention_layers"), crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
}, },
} }

View File

@ -1,6 +1,7 @@
package models package models
import ( import (
_ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n" _ "github.com/ollama/ollama/model/models/gemma3n"

View File

@ -43,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
value := attn.Value.Forward(ctx, hiddenStates) value := attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@ -59,7 +59,7 @@ type MLP struct {
} }
func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }
@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = batch.Outputs
} }
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
} }
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) {
headDim: int(c.Uint("attention.key_length")), headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1), ropeScale: c.Float("rope.scaling.factor", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
}, },
} }

View File

@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
} }
func init() { func init() {

View File

@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel {
originalContextLength: int(c.Uint("context_length", 128000)), originalContextLength: int(c.Uint("context_length", 128000)),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1), ropeScale: c.Float("rope.scaling.factor", 1),
}, },
} }
@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState) q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
// Shift applies rotary position embeddings to the key tensor for causal attention caching // Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
} }
// MLP implements the feed-forward network component with SwiGLU activation // MLP implements the feed-forward network component with SwiGLU activation
@ -90,7 +90,7 @@ type MLP struct {
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
// Apply SwiGLU activation gating // Apply SwiGLU activation gating
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
// Project back to hidden dimension // Project back to hidden dimension
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }

View File

@ -100,8 +100,7 @@ type VisionMLP struct {
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using activation as specified in config (likely GELU or SiLU/Swish) // Using activation as specified in config (likely GELU or SiLU/Swish)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates) gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
upOutput := mlp.Up.Forward(ctx, hiddenStates) hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }

View File

@ -30,10 +30,10 @@ func (o Options) headDim() int {
} }
type Attention struct { type Attention struct {
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Query *nn.Linear `gguf:"attn_q"` Query *nn.Linear `gguf:"attn_q"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"` Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"` Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"` Output *nn.Linear `gguf:"attn_output"`
} }
@ -52,8 +52,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = sa.QueryNorm.Forward(ctx, query, opts.eps) query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps) key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
@ -66,9 +66,9 @@ type MLP interface {
type sparse struct { type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"` Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.Linear `gguf:"ffn_gate_exps"` Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"` Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"` Down *nn.LinearBatch `gguf:"ffn_down_exps"`
} }
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts) hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts) experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
hiddenStates = hiddenStates.SILU(ctx)
hiddenStates = hiddenStates.Mul(ctx, upStates)
experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights) experts = experts.Mul(ctx, routingWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
@ -111,7 +107,8 @@ type dense struct {
} }
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).
SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }
@ -165,7 +162,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = batch.Outputs
} }
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
@ -176,7 +173,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
} }
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
} }
var _ model.Model = (*Model)(nil) var _ model.Model = (*Model)(nil)
@ -216,7 +213,7 @@ func New(c fs.Config) (model.Model, error) {
valueLength: int(c.Uint("attention.value_length")), valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1), ropeScale: c.Float("rope.scaling.factor", 1),
numExperts: int(c.Uint("expert_count")), numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")), numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true), normTopKProb: c.Bool("norm_top_k_prob", true),

37
model/parsers/parsers.go Normal file
View File

@ -0,0 +1,37 @@
package parsers
import (
"github.com/ollama/ollama/api"
)
type Parser interface {
Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error)
HasToolSupport() bool
HasThinkingSupport() bool
}
func ParserForName(name string) Parser {
switch name {
case "qwen3-coder":
parser := &Qwen3CoderParser{}
return parser
case "passthrough":
return &PassthroughParser{}
default:
return nil
}
}
type PassthroughParser struct{}
func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
return s, "", nil, nil
}
func (p *PassthroughParser) HasToolSupport() bool {
return false
}
func (p *PassthroughParser) HasThinkingSupport() bool {
return false
}

410
model/parsers/qwen3coder.go Normal file
View File

@ -0,0 +1,410 @@
package parsers
import (
"context"
"encoding/json"
"encoding/xml"
"fmt"
"log/slog"
"math"
"regexp"
"strconv"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type qwenParserState int
const (
toolOpenTag = "<tool_call>"
toolCloseTag = "</tool_call>"
)
const (
qwenParserState_LookingForToolStart qwenParserState = iota
qwenParserState_CollectingToolContent
)
type Qwen3CoderParser struct {
state qwenParserState
acc strings.Builder
}
func (p *Qwen3CoderParser) HasToolSupport() bool {
return true
}
func (p *Qwen3CoderParser) HasThinkingSupport() bool {
return false
}
func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
p.acc.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var sb strings.Builder
for _, event := range events {
switch event := event.(type) {
case qwenEventRawToolCall:
toolCall, err := parseToolCall(event, tools)
if err != nil {
slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case qwenEventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content
// events, we naively append them together here. See the note below about
// `qwenEvent`s for more details
sb.WriteString(event.content)
}
}
return sb.String(), "", toolCalls, nil
}
func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
var all []qwenEvent
keepLooping := true
for keepLooping {
var events []qwenEvent
events, keepLooping = eat(p)
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
}
return all
}
// we use some internal event types in order to communicate between `Add` and
// `eat`. We do this to support interleaving content and parallel tool calls in
// the parser, even though qwen3-coder isn't supposed to do this. Our API
// doesn't currently support models outputting multiple messages in a turn, so
// we wouldn't be able to represent it yet, but there's no reason to prevent the
// parser from supporting it, especially for future models if they end up using
// a similar format.
type qwenEvent interface {
isQwenEvent()
}
type qwenEventRawToolCall struct {
raw string
}
type qwenEventContent struct {
content string
}
func (qwenEventContent) isQwenEvent() {}
func (qwenEventRawToolCall) isQwenEvent() {}
// eat consumes the parser's buffer, and returns a list of any unambiguous
// events from the current parser state. If the parser transitions to another
// state, it may have additional events to emit on the next call, which is what
// the second return value indicates
func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
var events []qwenEvent
switch p.state {
case qwenParserState_LookingForToolStart:
if strings.Contains(p.acc.String(), toolOpenTag) {
// we found a full tool open tag, so we can emit the content before the
// tag, being sure to trim any trailing whitespace
split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
if len(before) > 0 {
events = append(events, qwenEventContent{content: before})
}
after := split[1]
p.acc.Reset()
p.acc.WriteString(after)
p.state = qwenParserState_CollectingToolContent
return events, true
} else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
// we found a partial tool open tag, so we can emit the unambiguous part,
// which is the (trailing-whitespace trimmed) content before the partial
// tool open tag
beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.acc.String()[:ambiguousStart]
ambiguous := p.acc.String()[ambiguousStart:]
p.acc.Reset()
p.acc.WriteString(ambiguous)
events = append(events, qwenEventContent{content: unambiguous})
return events, false
} else {
// we found content that is entirely not a tool call. We should withhold
// any trailing whitespace in case this is the end of the content
whitespaceLen := trailingWhitespaceLen(p.acc.String())
ambiguousStart := len(p.acc.String()) - whitespaceLen
unambiguous := p.acc.String()[:ambiguousStart]
ambiguous := p.acc.String()[ambiguousStart:]
p.acc.Reset()
p.acc.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwenEventContent{content: unambiguous})
}
return events, false
}
case qwenParserState_CollectingToolContent:
if strings.Contains(p.acc.String(), toolCloseTag) {
split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
before := split[0]
if len(before) == 0 {
slog.Warn("qwen tool call closing tag found but no content before it")
}
// remove any whitespace between the tool call and any content after it
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.acc.Reset()
p.acc.WriteString(after)
events = append(events, qwenEventRawToolCall{raw: before})
p.state = qwenParserState_LookingForToolStart
return events, true
} else {
// note that we don't need to check the overlap here because we only plan
// on parsing the tool call once we see the full closing tag. We don't
// stream back the unparsed tool content, so there's no need to be eager
// here
return events, false
}
default:
panic("unreachable")
}
}
// TODO(drifkin): move this to a shared location
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
func trailingWhitespaceLen(s string) int {
for i := len(s) - 1; i >= 0; i-- {
if !unicode.IsSpace(rune(s[i])) {
return len(s) - i - 1
}
}
return len(s)
}
type XMLFunctionCall struct {
XMLName xml.Name `xml:"function"`
Name string `xml:"name,attr"`
Parameters []XMLParameter `xml:"parameter"`
}
type XMLParameter struct {
Name string `xml:"name,attr"`
Value string `xml:",chardata"`
}
// parseToolCall parses a raw tool call string into an api.ToolCall.
// The raw string follows an xml-like format, here's an example:
//
// <function=get_current_temperature>
// <parameter=location>
// San Francisco
// </parameter>
// <parameter=unit>
// celsius
// </parameter>
// </function>
func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
toolCall := api.ToolCall{}
xmlString := transformToXML(raw.raw)
var functionCall XMLFunctionCall
err := xml.Unmarshal([]byte(xmlString), &functionCall)
if err != nil {
return api.ToolCall{}, err
}
toolCall.Function = api.ToolCallFunction{
Name: functionCall.Name,
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionCall.Name {
matchedTool = &tools[i]
break
}
}
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
for _, parameter := range functionCall.Parameters {
// Look up the parameter type if we found the tool
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
paramType = prop.Type
}
}
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
}
return toolCall, nil
}
// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
//
// For union types (multiple types in PropertyType, which we support but doesn't
// seem as though the reference parser does type coercion with those types in
// mind) we use a type precedence approach:
// 1. null - checked first regardless of declared types (matches reference implementation)
// 2. boolean - only "true"/"false" are valid booleans
// 3. integer - must parse as a whole number
// 4. number - must parse as numeric (returns int if no decimal part)
// 5. array - must parse as valid JSON array
// 6. object - must parse as valid JSON object
// 7. string - always succeeds (least specific type)
//
// This precedence ensures we return the most specific type that successfully parses,
// following the principle of least surprise. For example, with PropertyType{"string", "number"},
// "123" becomes 123 (number), while "hello" becomes "hello" (string).
func parseValue(raw string, paramType api.PropertyType) any {
// first remove a single leading newlines, and a single trailing newline (if
// they exist). This follows the reference implementation
raw = strings.TrimPrefix(raw, "\n")
raw = strings.TrimSuffix(raw, "\n")
// Check for null first (case-insensitive) - this takes precedence over any type
if strings.ToLower(raw) == "null" {
return nil
}
// If no type is specified, default to string
if len(paramType) == 0 {
return raw
}
// Check if any of the specified types match, using type precedence
// Order: boolean -> integer -> number -> array -> object -> string
typeSet := make(map[string]bool)
for _, t := range paramType {
typeSet[t] = true
}
// Try boolean first (most restrictive)
if typeSet["boolean"] {
lower := strings.ToLower(raw)
switch lower {
case "true":
return true
case "false":
return false
}
// If not a valid boolean but boolean is the only type, return false (matching reference)
if len(paramType) == 1 {
return false
}
// Otherwise try other types
}
// Try integer
if typeSet["integer"] {
if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
// Return as int if it fits in int32, otherwise int64
if i >= math.MinInt32 && i <= math.MaxInt32 {
return int(i)
}
return i
}
// If integer is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// Try number (float)
if typeSet["number"] {
if f, err := strconv.ParseFloat(raw, 64); err == nil {
// If the number has no decimal part, return as int (matching reference)
if f == math.Trunc(f) {
i := int64(f)
if i >= math.MinInt32 && i <= math.MaxInt32 {
return int(i)
}
return i
}
return f
}
// If number is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// Try array
if typeSet["array"] {
var arr []interface{}
if err := json.Unmarshal([]byte(raw), &arr); err == nil {
return arr
}
// If array is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// Try object
if typeSet["object"] {
var obj map[string]interface{}
if err := json.Unmarshal([]byte(raw), &obj); err == nil {
return obj
}
// If object is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// String always succeeds (or if "string" is in the type set)
if typeSet["string"] {
return raw
}
// If we get here, none of the types matched and string wasn't an option
// We return string as a fallback. The reference implementation will attempt
// to parse the value as a python literal, but we purposefully don't support
// that
return raw
}
var qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
// xml so that it can be parsed by any xml parser
func transformToXML(raw string) string {
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
// care to properly escape the string that becomes the attribute value
return qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
groups := qwenTagRegex.FindStringSubmatch(match)
tag := groups[1]
var escapedValue strings.Builder
xml.EscapeText(&escapedValue, []byte(groups[2]))
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
})
}

View File

@ -0,0 +1,830 @@
package parsers
import (
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
// tool creates a test tool with the given name and properties
func tool(name string, props map[string]api.ToolProperty) api.Tool {
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
t.Function.Parameters.Type = "object"
t.Function.Parameters.Properties = props
return t
}
func TestQwenParserStreaming(t *testing.T) {
type step struct {
input string
wantEvents []qwenEvent
}
cases := []struct {
desc string
steps []step
only bool
}{
{
desc: "simple message streamed word by word",
steps: []step{
{
input: "hi",
wantEvents: []qwenEvent{qwenEventContent{content: "hi"}},
},
{
input: " there",
wantEvents: []qwenEvent{qwenEventContent{content: " there"}},
},
},
},
{
desc: "content before tool call",
steps: []step{
{
input: "hi there<tool_call>",
wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}},
},
},
},
{
desc: "multiple tool calls in one message",
steps: []step{
{
input: "before1<tool_call>in tool call</tool_call>after1<tool_call>in tool call 2</tool_call>after2",
wantEvents: []qwenEvent{
qwenEventContent{content: "before1"},
qwenEventRawToolCall{raw: "in tool call"},
qwenEventContent{content: "after1"},
qwenEventRawToolCall{raw: "in tool call 2"},
qwenEventContent{content: "after2"},
},
},
},
},
{
desc: "tool calls with split tags",
steps: []step{
{
input: "before<tool",
wantEvents: []qwenEvent{
qwenEventContent{content: "before"},
},
},
{
input: "_call>in tool call</tool",
wantEvents: []qwenEvent{},
},
{
input: "_call>af",
wantEvents: []qwenEvent{
qwenEventRawToolCall{raw: "in tool call"},
qwenEventContent{content: "af"},
},
},
{
input: "ter",
wantEvents: []qwenEvent{
qwenEventContent{content: "ter"},
},
},
},
},
{
desc: "trailing whitespace between content and tool call",
steps: []step{
{
input: "abc\n<tool_call>def</tool_call>",
wantEvents: []qwenEvent{
qwenEventContent{content: "abc"},
qwenEventRawToolCall{raw: "def"},
},
},
},
},
{
desc: "trailing whitespace between tool call and content",
steps: []step{
{
input: "<tool_call>abc</tool_call>\ndef",
wantEvents: []qwenEvent{
qwenEventRawToolCall{raw: "abc"},
qwenEventContent{content: "def"},
},
},
},
},
{
desc: "empty content before tool call",
steps: []step{
{
input: "\n<tool_call>abc</tool_call>",
wantEvents: []qwenEvent{
qwenEventRawToolCall{raw: "abc"},
},
},
},
},
{
desc: "partial tool open tag fakeout",
steps: []step{
{
input: "abc\n<tool_call",
wantEvents: []qwenEvent{
// \n should not be emitted yet because `<tool_call` might be a tool
// open tag, in which case the whitespace should be trimmed
qwenEventContent{content: "abc"},
},
},
{
input: " fakeout",
wantEvents: []qwenEvent{
qwenEventContent{content: "\n<tool_call fakeout"},
},
},
},
},
{
desc: "token-by-token whitespace handling",
steps: []step{
{
input: "a",
wantEvents: []qwenEvent{
qwenEventContent{content: "a"},
},
},
{
input: "\n",
wantEvents: []qwenEvent{},
},
{
input: "b",
wantEvents: []qwenEvent{
qwenEventContent{content: "\nb"},
},
},
},
},
}
anyOnlies := false
for _, tc := range cases {
if tc.only {
anyOnlies = true
}
}
for _, tc := range cases {
if anyOnlies && !tc.only {
continue
}
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3CoderParser{}
for i, step := range tc.steps {
parser.acc.WriteString(step.input)
gotEvents := parser.parseEvents()
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
// avoid deep equal on empty vs. nil slices
continue
}
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
}
}
})
}
}
func TestQwenToolParser(t *testing.T) {
type step struct {
name string
rawToolCall string
tools []api.Tool
wantToolCall api.ToolCall
}
steps := []step{
{
name: "simple tool call",
tools: []api.Tool{},
rawToolCall: `<function=get_current_temperature>
<parameter=location>
San Francisco
</parameter>
<parameter=unit>
celsius
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_temperature",
Arguments: map[string]any{
"location": "San Francisco",
"unit": "celsius",
},
},
},
},
{
name: "names with spaces",
tools: []api.Tool{},
rawToolCall: `<function=get current temperature>
<parameter=location with spaces>
San Francisco
</parameter>
<parameter=unit with spaces>
celsius
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get current temperature",
Arguments: map[string]any{
"location with spaces": "San Francisco",
"unit with spaces": "celsius",
},
},
},
},
// this mirrors the reference implementation's behavior, but unclear if it
// ever happens. If so, then we should probably remove them instead, this
// test is to just document the current behavior and test that we don't get
// xml errors
{
name: "names with quotes",
tools: []api.Tool{},
rawToolCall: `<function="get current temperature">
<parameter="location with spaces">
San Francisco
</parameter>
<parameter="unit with spaces">
"celsius"
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "\"get current temperature\"",
Arguments: map[string]any{
"\"location with spaces\"": "San Francisco",
"\"unit with spaces\"": "\"celsius\"",
},
},
},
},
{
name: "tool call with typed parameters",
tools: []api.Tool{
tool("calculate", map[string]api.ToolProperty{
"x": {Type: api.PropertyType{"number"}},
"y": {Type: api.PropertyType{"integer"}},
"enabled": {Type: api.PropertyType{"boolean"}},
"items": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `<function=calculate>
<parameter=x>
3.14
</parameter>
<parameter=y>
42
</parameter>
<parameter=enabled>
true
</parameter>
<parameter=items>
["a", "b", "c"]
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: map[string]any{
"x": 3.14,
"y": 42,
"enabled": true,
"items": []any{"a", "b", "c"},
},
},
},
},
}
for i, step := range steps {
gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools)
if err != nil {
t.Errorf("step %d (%s): %v", i, step.name, err)
}
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
}
}
}
func TestQwenToolCallValueParsing(t *testing.T) {
cases := []struct {
desc string
raw string
paramType api.PropertyType
want any
}{
{
desc: "default string value (no type specified)",
paramType: api.PropertyType{},
raw: "some-string",
want: "some-string",
},
{
desc: "trim a single leading and trailing newline",
paramType: api.PropertyType{},
raw: "\nsome-string\n",
want: "some-string",
},
{
desc: "trim at most one leading and trailing newline",
paramType: api.PropertyType{},
raw: "\n\nsome-string\n\n",
want: "\nsome-string\n",
},
{
desc: "newline really has to be the first character to be trimmed",
paramType: api.PropertyType{},
raw: " \nsome-string\n ",
want: " \nsome-string\n ",
},
{
desc: "numeric type",
paramType: api.PropertyType{"number"},
raw: "123",
want: 123,
},
// Integer parsing tests
{
desc: "integer type",
paramType: api.PropertyType{"integer"},
raw: "42",
want: 42,
},
{
desc: "negative integer",
paramType: api.PropertyType{"integer"},
raw: "-100",
want: -100,
},
{
desc: "zero integer",
paramType: api.PropertyType{"integer"},
raw: "0",
want: 0,
},
{
desc: "integer with leading zeros",
paramType: api.PropertyType{"integer"},
raw: "007",
want: 7,
},
{
desc: "large integer",
paramType: api.PropertyType{"integer"},
raw: "2147483648", // Just beyond int32 max
want: int64(2147483648),
},
// Float/number parsing tests
{
desc: "float type",
paramType: api.PropertyType{"number"},
raw: "3.14",
want: 3.14,
},
{
desc: "negative float",
paramType: api.PropertyType{"number"},
raw: "-273.15",
want: -273.15,
},
{
desc: "float without decimal part",
paramType: api.PropertyType{"number"},
raw: "100.0",
want: 100,
},
{
desc: "scientific notation positive",
paramType: api.PropertyType{"number"},
raw: "1.23e5",
want: 123000, // Will be int since it has no decimal part
},
{
desc: "scientific notation negative",
paramType: api.PropertyType{"number"},
raw: "1.5e-3",
want: 0.0015,
},
{
desc: "very small float",
paramType: api.PropertyType{"number"},
raw: "0.00000001",
want: 0.00000001,
},
// String parsing tests
{
desc: "explicit string type",
paramType: api.PropertyType{"string"},
raw: "hello world",
want: "hello world",
},
{
desc: "string with special characters",
paramType: api.PropertyType{"string"},
raw: "/usr/local/bin/test-file_v2.0.sh",
want: "/usr/local/bin/test-file_v2.0.sh",
},
{
desc: "string with quotes",
paramType: api.PropertyType{"string"},
raw: `He said "hello" to me`,
want: `He said "hello" to me`,
},
{
desc: "multiline string",
paramType: api.PropertyType{"string"},
raw: "line one\nline two\nline three",
want: "line one\nline two\nline three",
},
{
desc: "empty string",
paramType: api.PropertyType{"string"},
raw: "",
want: "",
},
{
desc: "string that looks like a number",
paramType: api.PropertyType{"string"},
raw: "12345",
want: "12345",
},
// Boolean parsing tests
{
desc: "boolean true",
paramType: api.PropertyType{"boolean"},
raw: "true",
want: true,
},
{
desc: "boolean false",
paramType: api.PropertyType{"boolean"},
raw: "false",
want: false,
},
{
desc: "boolean case insensitive true",
paramType: api.PropertyType{"boolean"},
raw: "True",
want: true,
},
{
desc: "boolean case insensitive false",
paramType: api.PropertyType{"boolean"},
raw: "FALSE",
want: false,
},
// Null parsing tests
{
desc: "null value lowercase",
paramType: api.PropertyType{"string"},
raw: "null",
want: nil,
},
{
desc: "null value case insensitive",
paramType: api.PropertyType{"integer"},
raw: "NULL",
want: nil,
},
// Array parsing tests
{
desc: "array of strings",
paramType: api.PropertyType{"array"},
raw: `["foo", "bar", "baz"]`,
want: []any{"foo", "bar", "baz"},
},
{
desc: "array of numbers",
paramType: api.PropertyType{"array"},
raw: `[1, 2.5, 3]`,
want: []any{float64(1), 2.5, float64(3)},
},
{
desc: "array of mixed types",
paramType: api.PropertyType{"array"},
raw: `["string", 123, true, null]`,
want: []any{"string", float64(123), true, nil},
},
{
desc: "empty array",
paramType: api.PropertyType{"array"},
raw: `[]`,
want: []any{},
},
// Object parsing tests
{
desc: "simple object",
paramType: api.PropertyType{"object"},
raw: `{"key": "value", "number": 42}`,
want: map[string]any{"key": "value", "number": float64(42)},
},
{
desc: "nested object",
paramType: api.PropertyType{"object"},
raw: `{"outer": {"inner": "value"}}`,
want: map[string]any{"outer": map[string]any{"inner": "value"}},
},
{
desc: "empty object",
paramType: api.PropertyType{"object"},
raw: `{}`,
want: map[string]any{},
},
// Error cases and fallback behavior
{
desc: "invalid integer falls back to string",
paramType: api.PropertyType{"integer"},
raw: "not-a-number",
want: "not-a-number",
},
{
desc: "invalid float falls back to string",
paramType: api.PropertyType{"number"},
raw: "3.14.159",
want: "3.14.159",
},
{
desc: "invalid boolean falls back to false",
paramType: api.PropertyType{"boolean"},
raw: "yes",
want: false,
},
{
desc: "invalid JSON array falls back to string",
paramType: api.PropertyType{"array"},
raw: "[1, 2, unclosed",
want: "[1, 2, unclosed",
},
{
desc: "invalid JSON object falls back to string",
paramType: api.PropertyType{"object"},
raw: `{"key": unclosed`,
want: `{"key": unclosed`,
},
// Edge cases
{
desc: "integer overflow should use int64",
paramType: api.PropertyType{"integer"},
raw: "2147483648", // Beyond int32 max
want: int64(2147483648),
},
{
desc: "float with many decimal places",
paramType: api.PropertyType{"number"},
raw: "3.141592653589793",
want: 3.141592653589793,
},
{
desc: "string with JSON-like content",
paramType: api.PropertyType{"string"},
raw: `{"this": "is", "just": "a string"}`,
want: `{"this": "is", "just": "a string"}`,
},
{
desc: "whitespace-only string",
paramType: api.PropertyType{"string"},
raw: " ",
want: " ",
},
// Unknown parameter (no type specified in tools)
{
desc: "parameter not in tool definition defaults to string",
paramType: api.PropertyType{},
raw: "some value",
want: "some value",
},
// Union type tests
{
desc: "string or number union - valid number",
paramType: api.PropertyType{"string", "number"},
raw: "42.5",
want: 42.5,
},
{
desc: "string or number union - non-numeric string",
paramType: api.PropertyType{"string", "number"},
raw: "hello",
want: "hello",
},
{
desc: "number or string union - valid number (order shouldn't matter)",
paramType: api.PropertyType{"number", "string"},
raw: "42.5",
want: 42.5,
},
{
desc: "integer or null union - valid integer",
paramType: api.PropertyType{"integer", "null"},
raw: "123",
want: 123,
},
{
desc: "integer or null union - null value",
paramType: api.PropertyType{"integer", "null"},
raw: "null",
want: nil,
},
{
desc: "null or integer union - null value (order shouldn't matter)",
paramType: api.PropertyType{"null", "integer"},
raw: "null",
want: nil,
},
{
desc: "boolean or string union - valid boolean",
paramType: api.PropertyType{"boolean", "string"},
raw: "true",
want: true,
},
{
desc: "boolean or string union - non-boolean becomes string",
paramType: api.PropertyType{"boolean", "string"},
raw: "yes",
want: "yes",
},
{
desc: "string or boolean union - valid boolean (precedence test)",
paramType: api.PropertyType{"string", "boolean"},
raw: "false",
want: false, // Should be boolean, not string "false"
},
{
desc: "integer or number union - integer value",
paramType: api.PropertyType{"integer", "number"},
raw: "42",
want: 42,
},
{
desc: "integer or number union - float value",
paramType: api.PropertyType{"integer", "number"},
raw: "42.5",
want: 42.5,
},
{
desc: "number or integer union - integer value (precedence test)",
paramType: api.PropertyType{"number", "integer"},
raw: "42",
want: 42, // Should try integer first due to precedence
},
{
desc: "array or object union - valid array",
paramType: api.PropertyType{"array", "object"},
raw: `[1, 2, 3]`,
want: []any{float64(1), float64(2), float64(3)},
},
{
desc: "array or object union - valid object",
paramType: api.PropertyType{"array", "object"},
raw: `{"key": "value"}`,
want: map[string]any{"key": "value"},
},
{
desc: "object or array union - valid array (precedence test)",
paramType: api.PropertyType{"object", "array"},
raw: `[1, 2, 3]`,
want: []any{float64(1), float64(2), float64(3)},
},
{
desc: "complex multi-type union - null",
paramType: api.PropertyType{"string", "number", "boolean", "null"},
raw: "null",
want: nil,
},
{
desc: "complex multi-type union - boolean",
paramType: api.PropertyType{"string", "number", "boolean", "null"},
raw: "true",
want: true,
},
{
desc: "complex multi-type union - number",
paramType: api.PropertyType{"string", "number", "boolean", "null"},
raw: "3.14",
want: 3.14,
},
{
desc: "complex multi-type union - string",
paramType: api.PropertyType{"string", "number", "boolean", "null"},
raw: "hello",
want: "hello",
},
{
desc: "integer string union - integer string becomes integer",
paramType: api.PropertyType{"integer", "string"},
raw: "123",
want: 123,
},
{
desc: "string integer union - integer string becomes integer (precedence)",
paramType: api.PropertyType{"string", "integer"},
raw: "123",
want: 123, // Integer has higher precedence than string
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got := parseValue(tc.raw, tc.paramType)
if !reflect.DeepEqual(got, tc.want) {
t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want)
}
})
}
}
func TestQwenXMLTransform(t *testing.T) {
cases := []struct {
desc string
raw string
want string
}{
{
desc: "simple example",
raw: `<function=get_current_temperature>
<parameter=location>
San Francisco
</parameter>
<parameter=unit>
celsius
</parameter>
</function>`,
want: `<function name="get_current_temperature">
<parameter name="location">
San Francisco
</parameter>
<parameter name="unit">
celsius
</parameter>
</function>`,
},
// even though quotes aren't expected in these tags, we have these tests to
// make sure they're escaped so they don't blow up the xml parser in case
// they happen
{
desc: "names with quotes",
raw: `<function="get current temperature">
<parameter="location with spaces">
San Francisco
</parameter>
<parameter="unit with spaces">
celsius
</parameter>
</function>`,
want: `<function name="&#34;get current temperature&#34;">
<parameter name="&#34;location with spaces&#34;">
San Francisco
</parameter>
<parameter name="&#34;unit with spaces&#34;">
celsius
</parameter>
</function>`,
},
}
for _, tc := range cases {
got := transformToXML(tc.raw)
if got != tc.want {
t.Errorf("got %q, want %q", got, tc.want)
}
}
}
func TestTrailingWhitespaceLen(t *testing.T) {
cases := []struct {
desc string
s string
want int
}{
{desc: "no whitespace", s: "abc", want: 0},
{desc: "trailing whitespace", s: "abc ", want: 1},
{desc: "trailing whitespace with newlines", s: "abc \n", want: 2},
{desc: "only whitespace", s: " \n ", want: 4},
{desc: "leading whitespace doesn't count", s: " \n abc", want: 0},
}
for _, tc := range cases {
got := trailingWhitespaceLen(tc.s)
if got != tc.want {
t.Errorf("got %d, want %d", got, tc.want)
}
}
}

View File

@ -0,0 +1,217 @@
package renderers
import (
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/ollama/ollama/api"
)
var (
imStartTag = "<|im_start|>"
imEndTag = "<|im_end|>"
)
// renderAdditionalKeys renders all JSON fields except the ones in handledKeys
// This follows the same approach from the reference implementation, which gives
// a particular key ordering
func renderAdditionalKeys(obj any, handledKeys map[string]bool) string {
data, err := json.Marshal(obj)
if err != nil {
return ""
}
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return ""
}
var sb strings.Builder
for key, value := range m {
if handledKeys[key] {
continue
}
// Check if value is a map or array (needs JSON serialization)
switch v := value.(type) {
case map[string]any, []any:
jsonBytes, _ := json.Marshal(v)
// TODO(drifkin): it would be nice to format the JSON here similarly to
// python's default json.dumps behavior (spaces after commas and colons).
// This would let us be byte-for-byte compatible with the reference
// implementation for most common inputs
jsonStr := string(jsonBytes)
sb.WriteString("\n<" + key + ">" + jsonStr + "</" + key + ">")
case nil:
continue
default:
// Simple types, convert to string
sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "</" + key + ">")
}
}
return sb.String()
}
func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
var sb strings.Builder
// filter out system messages and choose the first (if any) to win
var systemMessage string
var filteredMessages []api.Message
for _, message := range messages {
if message.Role != "system" {
filteredMessages = append(filteredMessages, message)
continue
}
if systemMessage == "" {
systemMessage = message.Content
}
}
if systemMessage != "" || len(tools) > 0 {
sb.WriteString(imStartTag + "system\n")
// if we have tools but no system message, match the reference implementation by providing a default system message
if systemMessage == "" {
systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
}
sb.WriteString(systemMessage)
if len(tools) > 0 {
sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
sb.WriteString("<tools>")
for _, tool := range tools {
sb.WriteString("\n")
sb.WriteString("<function>\n")
sb.WriteString("<name>" + tool.Function.Name + "</name>")
if tool.Function.Description != "" {
sb.WriteString("\n<description>" + tool.Function.Description + "</description>")
}
sb.WriteString("\n<parameters>")
for name, prop := range tool.Function.Parameters.Properties {
sb.WriteString("\n<parameter>")
sb.WriteString("\n<name>" + name + "</name>")
if len(prop.Type) > 0 {
// TODO(!!!)(drifkin): we should match the reference implementation for
// more complex types here instead of using this format
sb.WriteString("\n<type>" + prop.ToTypeScriptType() + "</type>")
}
if prop.Description != "" {
sb.WriteString("\n<description>" + prop.Description + "</description>")
}
// Render any additional keys not already handled
handledKeys := map[string]bool{
"type": true,
"description": true,
}
sb.WriteString(renderAdditionalKeys(prop, handledKeys))
sb.WriteString("\n</parameter>")
}
// Render extra keys for parameters (everything except 'type' and 'properties')
paramHandledKeys := map[string]bool{
"type": true,
"properties": true,
}
sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
sb.WriteString("\n</parameters>")
sb.WriteString("\n</function>")
}
sb.WriteString("\n</tools>")
sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
}
sb.WriteString(imEndTag + "\n")
}
for i, message := range filteredMessages {
lastMessage := i == len(filteredMessages)-1
prefill := lastMessage && message.Role == "assistant"
switch message.Role {
case "assistant":
if len(message.ToolCalls) > 0 {
sb.WriteString(imStartTag + "assistant\n")
if message.Content != "" {
sb.WriteString(message.Content + "\n")
}
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
for name, value := range toolCall.Function.Arguments {
valueStr := formatToolCallArgument(value)
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
}
sb.WriteString("\n</function>\n</tool_call>")
}
sb.WriteString("<|im_end|>\n")
} else {
sb.WriteString(imStartTag + "assistant\n")
sb.WriteString(message.Content)
if !prefill {
sb.WriteString(imEndTag + "\n")
}
}
case "tool":
// consecutive tool responses should share a single `<im_start>user`, but
// have their own <tool_response> tags
// only start a new user block if this is the first tool response
if i == 0 || filteredMessages[i-1].Role != "tool" {
sb.WriteString(imStartTag + "user\n")
}
sb.WriteString("<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>\n")
// close the user block only if this is the last tool response
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
sb.WriteString(imEndTag + "\n")
}
default:
sb.WriteString(imStartTag + message.Role + "\n")
sb.WriteString(message.Content)
sb.WriteString(imEndTag + "\n")
}
if lastMessage && !prefill {
sb.WriteString(imStartTag + "assistant\n")
}
}
return sb.String(), nil
}
func formatToolCallArgument(value any) string {
if value == nil {
return "null"
}
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
}
if reflect.TypeOf(value) != nil {
kind := reflect.TypeOf(value).Kind()
if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array {
if marshalled, err := json.Marshal(value); err == nil {
return string(marshalled)
}
}
}
return fmt.Sprintf("%v", value)
}

View File

@ -0,0 +1,338 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestQwen3CoderRenderer(t *testing.T) {
tests := []struct {
name string
msgs []api.Message
tools []api.Tool
expected string
}{
{
name: "basic",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
expected: `<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
`,
},
{
name: "with tools and response",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in San Francisco?"},
{
Role: "assistant",
Content: "I'll check the weather in San Francisco for you.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{
"unit": "fahrenheit",
},
},
},
},
},
{Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"},
{Role: "user", Content: "That sounds nice! What about New York?"},
},
tools: []api.Tool{
{Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Required: []string{"unit"},
Properties: map[string]api.ToolProperty{
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
// TODO(drifkin): add multiple params back once we have predictable
// order via some sort of ordered map type (see
// <https://github.com/ollama/ollama/issues/12244>)
/*
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
*/
},
},
}},
},
expected: `<|im_start|>system
You are a helpful assistant with access to tools.
# Tools
You have access to the following functions:
<tools>
<function>
<name>get_weather</name>
<description>Get the current weather in a given location</description>
<parameters>
<parameter>
<name>unit</name>
<type>string</type>
<description>The unit of temperature</description>
<enum>["celsius","fahrenheit"]</enum>
</parameter>
<required>["unit"]</required>
</parameters>
</function>
</tools>
If you choose to call a function ONLY reply in the following format with NO suffix:
<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>
<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT><|im_end|>
<|im_start|>user
What is the weather like in San Francisco?<|im_end|>
<|im_start|>assistant
I'll check the weather in San Francisco for you.
<tool_call>
<function=get_weather>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
</tool_response>
<|im_end|>
<|im_start|>user
That sounds nice! What about New York?<|im_end|>
<|im_start|>assistant
`,
},
{
name: "parallel tool calls",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "call double(1) and triple(2)"},
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
}},
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
},
tools: []api.Tool{
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
}}}},
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
}}}},
},
expected: `<|im_start|>system
You are a helpful assistant with access to tools.
# Tools
You have access to the following functions:
<tools>
<function>
<name>double</name>
<description>Double a number</description>
<parameters>
<parameter>
<name>number</name>
<type>string</type>
<description>The number to double</description>
</parameter>
</parameters>
</function>
<function>
<name>triple</name>
<description>Triple a number</description>
<parameters>
<parameter>
<name>number</name>
<type>string</type>
<description>The number to triple</description>
</parameter>
</parameters>
</function>
</tools>
If you choose to call a function ONLY reply in the following format with NO suffix:
<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>
<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT><|im_end|>
<|im_start|>user
call double(1) and triple(2)<|im_end|>
<|im_start|>assistant
I'll call double(1) and triple(2) for you.
<tool_call>
<function=double>
<parameter=number>
1
</parameter>
</function>
</tool_call>
<tool_call>
<function=triple>
<parameter=number>
2
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"number": 2}
</tool_response>
<tool_response>
{"number": 6}
</tool_response>
<|im_end|>
<|im_start|>assistant
`,
},
{
name: "prefill",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Tell me something interesting."},
{Role: "assistant", Content: "I'll tell you something interesting about cats"},
},
expected: `<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Tell me something interesting.<|im_end|>
<|im_start|>assistant
I'll tell you something interesting about cats`,
},
{
name: "complex tool call arguments should remain json encoded",
msgs: []api.Message{
{Role: "user", Content: "call tool"},
{Role: "assistant", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{
Name: "echo",
Arguments: map[string]any{
"payload": map[string]any{"foo": "bar"},
},
}},
}},
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
},
expected: `<|im_start|>user
call tool<|im_end|>
<|im_start|>assistant
<tool_call>
<function=echo>
<parameter=payload>
{"foo":"bar"}
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"payload": {"foo": "bar"}}
</tool_response>
<|im_end|>
<|im_start|>assistant
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestFormatToolCallArgument(t *testing.T) {
tests := []struct {
name string
arg any
expected string
}{
{
name: "string",
arg: "foo",
// notice no quotes around the string
expected: "foo",
},
{
name: "map",
arg: map[string]any{"foo": "bar"},
expected: "{\"foo\":\"bar\"}",
},
{
name: "number",
arg: 1,
expected: "1",
},
{
name: "boolean",
arg: true,
expected: "true",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := formatToolCallArgument(tt.arg)
if got != tt.expected {
t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
}
})
}
}

View File

@ -0,0 +1,26 @@
package renderers
import (
"fmt"
"github.com/ollama/ollama/api"
)
type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)
func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
renderer := rendererForName(name)
if renderer == nil {
return "", fmt.Errorf("unknown renderer %q", name)
}
return renderer(msgs, tools, think)
}
func rendererForName(name string) rendererFunc {
switch name {
case "qwen3-coder":
return Qwen3CoderRenderer
default:
return nil
}
}

View File

@ -12,18 +12,18 @@ import (
const spmWhitespaceSep = "▁" const spmWhitespaceSep = "▁"
type SentencePieceModel struct { type SentencePiece struct {
maxTokenLen int maxTokenLen int
vocab *Vocabulary vocab *Vocabulary
} }
var _ TextProcessor = (*SentencePieceModel)(nil) var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePieceModel) Vocabulary() *Vocabulary { func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab return spm.vocab
} }
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{} counter := map[int]int{}
@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen) "max token len", maxTokenLen)
return SentencePieceModel{ return SentencePiece{
maxTokenLen: maxTokenLen, maxTokenLen: maxTokenLen,
vocab: vocab, vocab: vocab,
} }
} }
func (spm SentencePieceModel) Is(id int32, special Special) bool { func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special) return spm.vocab.Is(id, special)
} }
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) { func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}} fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() { for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special) id := spm.vocab.Encode(special)
@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
return item return item
} }
func (spm SentencePieceModel) Decode(ids []int32) (string, error) { func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder var sb strings.Builder
for _, id := range ids { for _, id := range ids {
data := spm.vocab.Decode(id) data := spm.vocab.Decode(id)

View File

@ -12,7 +12,7 @@ import (
"github.com/ollama/ollama/convert/sentencepiece" "github.com/ollama/ollama/convert/sentencepiece"
) )
func loadSentencePieceVocab(t *testing.T) SentencePieceModel { func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper() t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model")) bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
} }
} }
return NewSentencePieceModel(&v) return NewSentencePiece(&v)
} }
func TestSentencePieceEncode(t *testing.T) { func TestSentencePieceEncode(t *testing.T) {
@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
}) })
} }
func TestSentencePieceModelDecodeByteTokens(t *testing.T) { func TestSentencePieceDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{ vocab := &Vocabulary{
Values: []string{ Values: []string{
"normal", "normal",
@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
Scores: []float32{0, 0, 0, 0, 0}, Scores: []float32{0, 0, 0, 0, 0},
} }
spm := NewSentencePieceModel(vocab) spm := NewSentencePiece(vocab)
tests := []struct { tests := []struct {
name string name string

167
model/wordpiece.go Normal file
View File

@ -0,0 +1,167 @@
package model
import (
"fmt"
"iter"
"strings"
"unicode"
"github.com/ollama/ollama/logutil"
)
type WordPiece struct {
vocab *Vocabulary
}
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// this differs from original word piece which uses "##" to indicate subwords.
const ggmlPrefix = "▁"
var wordPieceReplacer = strings.NewReplacer(
" .", ".",
" ?", "?",
" !", "!",
" ,", ",",
" ' ", "'",
" n't", "n't",
" 'm", "'m",
" do not", " don't",
" 's", "'s",
" 've", "'ve",
" 're", "'re",
)
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
if id < 0 || int(id) >= len(wpm.vocab.Values) {
return "", fmt.Errorf("invalid token id: %d", id)
}
var separator string
piece := wpm.vocab.Values[id]
if i > 0 &&
(strings.HasPrefix(piece, ggmlPrefix) ||
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
separator = " "
}
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
}
return sb.String(), nil
}
// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
return func(yield func(string) bool) {
runes := make([]rune, 0, len(s)*3)
for _, r := range s {
switch {
case r >= 0x4E00 && r <= 0x9FFF,
r >= 0x3400 && r <= 0x4DBF,
r >= 0x20000 && r <= 0x2A6DF,
r >= 0x2A700 && r <= 0x2B73F,
r >= 0x2B740 && r <= 0x2B81F,
r >= 0x2B820 && r <= 0x2CEAF,
r >= 0xF900 && r <= 0xFAFF,
r >= 0x2F800 && r <= 0x2FA1F:
runes = append(runes, ' ', r, ' ')
default:
runes = append(runes, r)
}
}
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
// split on but keep punctuation
var start int
for start < len(w) {
end := strings.IndexFunc(w[start:], unicode.IsPunct)
if end < 0 {
end = len(w) - start
} else if end == 0 {
end = 1
}
if !yield(w[start : start+end]) {
return
}
start += end
}
}
}
}
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
// TODO: use [UNK] from config
unk := wpm.vocab.Encode("[UNK]")
for word := range wpm.words(s) {
var start int
var pieces []int32
for start < len(word) {
end := len(word)
var piece int32
for start < end {
subword := word[start:end]
if start == 0 {
subword = ggmlPrefix + subword
}
// TODO: some models might not want [ToLower]
piece = wpm.vocab.Encode(strings.ToLower(subword))
if piece >= 0 {
break
}
end--
}
if piece < 0 {
// Unknown token
pieces = pieces[:0]
break
}
pieces = append(pieces, piece)
start = end
}
if len(pieces) > 0 {
ids = append(ids, pieces...)
} else {
ids = append(ids, unk)
}
}
if addSpecial && len(ids) > 0 {
ids = wpm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary) WordPiece {
return WordPiece{
vocab: vocab,
}
}

51
model/wordpiece_test.go Normal file
View File

@ -0,0 +1,51 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
})
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}

View File

@ -78,6 +78,7 @@ type JsonSchema struct {
type EmbedRequest struct { type EmbedRequest struct {
Input any `json:"input"` Input any `json:"input"`
Model string `json:"model"` Model string `json:"model"`
Dimensions int `json:"dimensions,omitempty"`
} }
type StreamOptions struct { type StreamOptions struct {
@ -104,6 +105,7 @@ type ChatCompletionRequest struct {
Tools []api.Tool `json:"tools"` Tools []api.Tool `json:"tools"`
Reasoning *Reasoning `json:"reasoning,omitempty"` Reasoning *Reasoning `json:"reasoning,omitempty"`
ReasoningEffort *string `json:"reasoning_effort,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"`
DebugRenderOnly bool `json:"_debug_render_only"`
} }
type ChatCompletion struct { type ChatCompletion struct {
@ -114,6 +116,7 @@ type ChatCompletion struct {
SystemFingerprint string `json:"system_fingerprint"` SystemFingerprint string `json:"system_fingerprint"`
Choices []Choice `json:"choices"` Choices []Choice `json:"choices"`
Usage Usage `json:"usage,omitempty"` Usage Usage `json:"usage,omitempty"`
DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"`
} }
type ChatCompletionChunk struct { type ChatCompletionChunk struct {
@ -140,6 +143,7 @@ type CompletionRequest struct {
Temperature *float32 `json:"temperature"` Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"` TopP float32 `json:"top_p"`
Suffix string `json:"suffix"` Suffix string `json:"suffix"`
DebugRenderOnly bool `json:"_debug_render_only"`
} }
type Completion struct { type Completion struct {
@ -272,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
} }
return nil return nil
}(r.DoneReason), }(r.DoneReason),
}}, }}, Usage: toUsage(r),
Usage: toUsage(r), DebugInfo: r.DebugInfo,
} }
} }
@ -574,6 +578,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
Stream: &r.Stream, Stream: &r.Stream,
Tools: r.Tools, Tools: r.Tools,
Think: think, Think: think,
DebugRenderOnly: r.DebugRenderOnly,
}, nil }, nil
} }
@ -652,6 +657,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
Options: options, Options: options,
Stream: &r.Stream, Stream: &r.Stream,
Suffix: r.Suffix, Suffix: r.Suffix,
DebugRenderOnly: r.DebugRenderOnly,
}, nil }, nil
} }
@ -1005,7 +1011,7 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
} }
var b bytes.Buffer var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil { if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return return
} }

View File

@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
req.System = c.Args req.System = c.Args
case "license": case "license":
licenses = append(licenses, c.Args) licenses = append(licenses, c.Args)
case "renderer":
req.Renderer = c.Args
case "parser":
req.Parser = c.Args
case "message": case "message":
role, msg, _ := strings.Cut(c.Args, ": ") role, msg, _ := strings.Cut(c.Args, ": ")
messages = append(messages, api.Message{Role: role, Content: msg}) messages = append(messages, api.Message{Role: role, Content: msg})
@ -320,7 +324,7 @@ func (c Command) String() string {
switch c.Name { switch c.Name {
case "model": case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args) fmt.Fprintf(&sb, "FROM %s", c.Args)
case "license", "template", "system", "adapter": case "license", "template", "system", "adapter", "renderer", "parser":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
case "message": case "message":
role, message, _ := strings.Cut(c.Args, ": ") role, message, _ := strings.Cut(c.Args, ": ")
@ -346,7 +350,7 @@ const (
var ( var (
errMissingFrom = errors.New("no FROM line") errMissingFrom = errors.New("no FROM line")
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"")
) )
type ParserError struct { type ParserError struct {
@ -606,7 +610,7 @@ func isValidMessageRole(role string) bool {
func isValidCommand(cmd string) bool { func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) { switch strings.ToLower(cmd) {
case "from", "license", "template", "system", "adapter", "parameter", "message": case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message":
return true return true
default: default:
return false return false

View File

@ -198,6 +198,34 @@ BADCOMMAND param1 value1
} }
} }
func TestParseFileRenderer(t *testing.T) {
input := `
FROM foo
RENDERER renderer1
`
reader := strings.NewReader(input)
modelfile, err := ParseFile(reader)
require.NoError(t, err)
assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands)
}
func TestParseFileParser(t *testing.T) {
input := `
FROM foo
PARSER parser1
`
reader := strings.NewReader(input)
modelfile, err := ParseFile(reader)
require.NoError(t, err)
assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands)
}
func TestParseFileMessages(t *testing.T) { func TestParseFileMessages(t *testing.T) {
cases := []struct { cases := []struct {
input string input string

View File

@ -204,13 +204,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
targetFree = max(targetFree, 1) targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen currentFree := c.numCtx - inputLen
discard := targetFree - currentFree
if discard < 0 { return max(targetFree-currentFree, 0)
discard = 0
}
return discard
} }
type ErrReprocessInputs struct { type ErrReprocessInputs struct {

View File

@ -242,13 +242,8 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
targetFree = max(targetFree, 1) targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen currentFree := c.numCtx - inputLen
discard := targetFree - currentFree
if discard < 0 { return max(targetFree-currentFree, 0)
discard = 0
}
return discard
} }
type ErrReprocessInputs struct { type ErrReprocessInputs struct {

View File

@ -11,14 +11,12 @@ import (
"image" "image"
"log" "log"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"os" "os"
"reflect" "reflect"
"regexp" "regexp"
"runtime" "runtime"
"runtime/debug"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -30,10 +28,10 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/runner/common"
@ -407,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) run(ctx context.Context) { func (s *Server) run(ctx context.Context) {
s.ready.Wait() s.ready.Wait()
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
var activeBatch batchState var activeBatch batchState
for { for {
@ -469,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet // Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input var batchInputs []*input.Input
var batchOutputs []int32
var batch input.Batch var batch input.Batch
resumeSeq := -1 resumeSeq := -1
@ -551,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id) batch.Sequences = append(batch.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs) seq.iBatch = len(batchOutputs)
if i+1 == len(seq.inputs) { if i+1 == len(seq.inputs) || seq.embeddingOnly {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
} }
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs)) logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp) seq.pendingInputs = append(seq.pendingInputs, inp)
@ -578,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs)) batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch) nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
if err != nil { if err != nil {
err = fmt.Errorf("failed to build graph: %w", err) err = fmt.Errorf("failed to build graph: %w", err)
@ -705,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
} }
// sample a token // sample a token
vocabSize := len(outputs) / len(activeBatch.batch.Outputs) vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches) logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil { if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err) s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
@ -782,14 +782,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if req.UseHarmony {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillString)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
}
if req.Options == nil { if req.Options == nil {
opts := api.DefaultOptions() opts := api.DefaultOptions()
req.Options = &opts req.Options = &opts
@ -872,9 +864,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
http.Error(w, "could not find an available sequence", http.StatusInternalServerError) http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return return
} }
var lastToken string
tokenRepeat := 0
const tokenRepeatLimit = 30
for { for {
select { select {
@ -883,27 +872,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
case content, ok := <-seq.responses: case content, ok := <-seq.responses:
if ok { if ok {
if strings.TrimSpace(content) == lastToken {
tokenRepeat++
}
if tokenRepeat == tokenRepeatLimit {
http.Error(w, "token repeat limit reached", http.StatusInternalServerError)
seq.doneReason = llm.DoneReasonTokenRepeatLimit
close(seq.quit)
return
}
lastToken = strings.TrimSpace(content)
var thinking string
if harmonyMessageHandler != nil {
var toolContent string
content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
harmonyToolParser.Add(toolContent)
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content, Content: content,
Thinking: thinking,
}); err != nil { }); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit) close(seq.quit)
@ -912,29 +882,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush() flusher.Flush()
} else { } else {
var toolCalls []api.ToolCall
if harmonyMessageHandler != nil {
// these tools still need to be transformed to the original function name
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)
close(seq.quit)
return
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
ToolCalls: toolCalls,
Done: true, Done: true,
DoneReason: seq.doneReason, DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs, PromptEvalCount: seq.numPromptInputs,
@ -952,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 { if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
return return
} }
@ -1100,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
batch.Positions[i] = int32(i) batch.Positions[i] = int32(i)
} }
batch.Outputs = make([]int32, s.parallel)
for i := range batch.Outputs {
batch.Outputs[i] = int32(i)
}
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
cache := s.model.Config().Cache cache := s.model.Config().Cache
if cache != nil { if cache != nil {
@ -1139,9 +1083,13 @@ func (s *Server) allocModel(
// Convert memory allocation panics to errors // Convert memory allocation panics to errors
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
debug.PrintStack()
if err, ok := r.(error); ok { if err, ok := r.(error); ok {
panicErr = err var noMem ml.ErrNoMem
if errors.As(err, &noMem) {
panicErr = noMem
} else {
panic(r)
}
} else { } else {
panic(r) panic(r)
} }

View File

@ -78,7 +78,7 @@ function checkEnv() {
} }
function buildOllama() { function buildCPU() {
mkdir -Force -path "${script:DIST_DIR}\" mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") { if ($script:ARCH -ne "arm64") {
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}" Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
@ -90,20 +90,72 @@ function buildOllama() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component CPU --strip & cmake --install build --component CPU --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
function buildCUDA11() {
# CUDA v11 claims to be compatible with MSVC 2022, but the latest updates are no longer compatible
# 19.40 is the last compiler version that works, but recent udpates are 19.43
# So this pins to MSVC 2019 for best compatibility
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{} $hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v12")) { if ("$script:CUDA_DIRS".Contains("v11")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }} $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
$env:CUDAToolkit_ROOT=$hashEnv[$v12] write-host "Building CUDA v11 backend libraries $cuda"
write-host "Building CUDA v12 backend libraries" $env:CUDAToolkit_ROOT=$cuda
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR & cmake --fresh --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildCUDA12() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v12.8")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12_8")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
write-host "Building CUDA v12 backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake --fresh --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS & cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip & cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} }
}
}
function buildCUDA13() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v13")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V13")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
$env:CUDAToolkit_ROOT=$cuda
write-host "Building CUDA v13 backend libraries $cuda"
& cmake --fresh --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 13" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildROCm() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
if ($env:HIP_PATH) { if ($env:HIP_PATH) {
write-host "Building ROCm backend libraries" write-host "Building ROCm backend libraries"
if (-Not (get-command -ErrorAction silent ninja)) { if (-Not (get-command -ErrorAction silent ninja)) {
@ -129,6 +181,10 @@ function buildOllama() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} }
} }
}
function buildOllama() {
mkdir -Force -path "${script:DIST_DIR}\"
write-host "Building ollama CLI" write-host "Building ollama CLI"
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" . & go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@ -236,6 +292,10 @@ function distZip() {
checkEnv checkEnv
try { try {
if ($($args.count) -eq 0) { if ($($args.count) -eq 0) {
buildCPU
buildCUDA12
buildCUDA13
buildROCm
buildOllama buildOllama
buildApp buildApp
gatherDependencies gatherDependencies

View File

@ -10,8 +10,11 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"net/url"
"os" "os"
"path"
"path/filepath" "path/filepath"
"slices" "slices"
"strings" "strings"
@ -39,6 +42,14 @@ var (
) )
func (s *Server) CreateHandler(c *gin.Context) { func (s *Server) CreateHandler(c *gin.Context) {
config := &ConfigV2{
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
}
var r api.CreateRequest var r api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@ -48,6 +59,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
return return
} }
config.Renderer = r.Renderer
config.Parser = r.Parser
for v := range r.Files { for v := range r.Files {
if !fs.ValidPath(v) { if !fs.ValidPath(v) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
@ -77,14 +91,27 @@ func (s *Server) CreateHandler(c *gin.Context) {
oldManifest, _ := ParseNamedManifest(name) oldManifest, _ := ParseNamedManifest(name)
var baseLayers []*layerGGML var baseLayers []*layerGGML
var err error
var remote bool
if r.From != "" { if r.From != "" {
slog.Debug("create model from model name") slog.Debug("create model from model name", "from", r.From)
fromName := model.ParseName(r.From) fromName := model.ParseName(r.From)
if !fromName.IsValid() { if !fromName.IsValid() {
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest} ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
return return
} }
if r.RemoteHost != "" {
ru, err := remoteURL(r.RemoteHost)
if err != nil {
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
return
}
config.RemoteModel = r.From
config.RemoteHost = ru
remote = true
} else {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
@ -92,6 +119,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
if err != nil { if err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}
} else if r.Files != nil { } else if r.Files != nil {
baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn) baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)
if err != nil { if err != nil {
@ -110,7 +138,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
} }
var adapterLayers []*layerGGML var adapterLayers []*layerGGML
if r.Adapters != nil { if !remote && r.Adapters != nil {
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn) adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
if err != nil { if err != nil {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} { for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
@ -128,7 +156,56 @@ func (s *Server) CreateHandler(c *gin.Context) {
baseLayers = append(baseLayers, adapterLayers...) baseLayers = append(baseLayers, adapterLayers...)
} }
if err := createModel(r, name, baseLayers, fn); err != nil { // Info is not currently exposed by Modelfiles, but allows overriding various
// config values
if r.Info != nil {
caps, ok := r.Info["capabilities"]
if ok {
switch tcaps := caps.(type) {
case []any:
caps := make([]string, len(tcaps))
for i, c := range tcaps {
str, ok := c.(string)
if !ok {
continue
}
caps[i] = str
}
config.Capabilities = append(config.Capabilities, caps...)
}
}
strFromInfo := func(k string) string {
v, ok := r.Info[k]
if ok {
val := v.(string)
return val
}
return ""
}
vFromInfo := func(k string) float64 {
v, ok := r.Info[k]
if ok {
val := v.(float64)
return val
}
return 0
}
config.ModelFamily = strFromInfo("model_family")
if config.ModelFamily != "" {
config.ModelFamilies = []string{config.ModelFamily}
}
config.BaseName = strFromInfo("base_name")
config.FileType = strFromInfo("quantization_level")
config.ModelType = strFromInfo("parameter_size")
config.ContextLen = int(vFromInfo("context_length"))
config.EmbedLen = int(vFromInfo("embedding_length"))
}
if err := createModel(r, name, baseLayers, config, fn); err != nil {
if errors.Is(err, errBadTemplate) { if errors.Is(err, errBadTemplate) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest} ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
return return
@ -154,6 +231,51 @@ func (s *Server) CreateHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func remoteURL(raw string) (string, error) {
// Specialcase: user supplied only a path ("/foo/bar").
if strings.HasPrefix(raw, "/") {
return (&url.URL{
Scheme: "http",
Host: net.JoinHostPort("localhost", "11434"),
Path: path.Clean(raw),
}).String(), nil
}
if !strings.Contains(raw, "://") {
raw = "http://" + raw
}
if raw == "ollama.com" || raw == "http://ollama.com" {
raw = "https://ollama.com:443"
}
u, err := url.Parse(raw)
if err != nil {
return "", fmt.Errorf("parse error: %w", err)
}
if u.Host == "" {
u.Host = "localhost"
}
hostPart, portPart, err := net.SplitHostPort(u.Host)
if err == nil {
u.Host = net.JoinHostPort(hostPart, portPart)
} else {
u.Host = net.JoinHostPort(u.Host, "11434")
}
if u.Path != "" {
u.Path = path.Clean(u.Path)
}
if u.Path == "/" {
u.Path = ""
}
return u.String(), nil
}
func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) { func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
switch detectModelTypeFromFiles(files) { switch detectModelTypeFromFiles(files) {
case "safetensors": case "safetensors":
@ -316,15 +438,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
return ggml.KV{}, fmt.Errorf("no base model was found") return ggml.KV{}, fmt.Errorf("no base model was found")
} }
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) { func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
}
var layers []Layer var layers []Layer
for _, layer := range baseLayers { for _, layer := range baseLayers {
if layer.GGML != nil { if layer.GGML != nil {
@ -404,7 +518,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
return err return err
} }
configLayer, err := createConfigLayer(layers, config) configLayer, err := createConfigLayer(layers, *config)
if err != nil { if err != nil {
return err return err
} }

View File

@ -104,3 +104,154 @@ func TestConvertFromSafetensors(t *testing.T) {
}) })
} }
} }
func TestRemoteURL(t *testing.T) {
tests := []struct {
name string
input string
expected string
hasError bool
}{
{
name: "absolute path",
input: "/foo/bar",
expected: "http://localhost:11434/foo/bar",
hasError: false,
},
{
name: "absolute path with cleanup",
input: "/foo/../bar",
expected: "http://localhost:11434/bar",
hasError: false,
},
{
name: "root path",
input: "/",
expected: "http://localhost:11434/",
hasError: false,
},
{
name: "host without scheme",
input: "example.com",
expected: "http://example.com:11434",
hasError: false,
},
{
name: "host with port",
input: "example.com:8080",
expected: "http://example.com:8080",
hasError: false,
},
{
name: "full URL",
input: "https://example.com:8080/path",
expected: "https://example.com:8080/path",
hasError: false,
},
{
name: "full URL with path cleanup",
input: "https://example.com:8080/path/../other",
expected: "https://example.com:8080/other",
hasError: false,
},
{
name: "ollama.com special case",
input: "ollama.com",
expected: "https://ollama.com:443",
hasError: false,
},
{
name: "http ollama.com special case",
input: "http://ollama.com",
expected: "https://ollama.com:443",
hasError: false,
},
{
name: "URL with only host",
input: "http://example.com",
expected: "http://example.com:11434",
hasError: false,
},
{
name: "URL with root path cleaned",
input: "http://example.com/",
expected: "http://example.com:11434",
hasError: false,
},
{
name: "invalid URL",
input: "http://[::1]:namedport", // invalid port
expected: "",
hasError: true,
},
{
name: "empty string",
input: "",
expected: "http://localhost:11434",
hasError: false,
},
{
name: "host with scheme but no port",
input: "http://localhost",
expected: "http://localhost:11434",
hasError: false,
},
{
name: "complex path cleanup",
input: "/a/b/../../c/./d",
expected: "http://localhost:11434/c/d",
hasError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := remoteURL(tt.input)
if tt.hasError {
if err == nil {
t.Errorf("expected error but got none")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if result != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, result)
}
})
}
}
func TestRemoteURL_Idempotent(t *testing.T) {
// Test that applying remoteURL twice gives the same result as applying it once
testInputs := []string{
"/foo/bar",
"example.com",
"https://example.com:8080/path",
"ollama.com",
"http://localhost:11434",
}
for _, input := range testInputs {
t.Run(input, func(t *testing.T) {
firstResult, err := remoteURL(input)
if err != nil {
t.Fatalf("first call failed: %v", err)
}
secondResult, err := remoteURL(firstResult)
if err != nil {
t.Fatalf("second call failed: %v", err)
}
if firstResult != secondResult {
t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult)
}
})
}
}

View File

@ -24,6 +24,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf" "github.com/ollama/ollama/fs/gguf"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking" "github.com/ollama/ollama/thinking"
@ -73,6 +74,7 @@ func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{} capabilities := []model.Capability{}
// Check for completion capability // Check for completion capability
if m.ModelPath != "" {
f, err := gguf.Open(m.ModelPath) f, err := gguf.Open(m.ModelPath)
if err == nil { if err == nil {
defer f.Close() defer f.Close()
@ -89,13 +91,21 @@ func (m *Model) Capabilities() []model.Capability {
} else { } else {
slog.Error("couldn't open model file", "error", err) slog.Error("couldn't open model file", "error", err)
} }
} else if len(m.Config.Capabilities) > 0 {
for _, c := range m.Config.Capabilities {
capabilities = append(capabilities, model.Capability(c))
}
} else {
slog.Warn("unknown capabilities for model", "model", m.Name)
}
if m.Template == nil { if m.Template == nil {
return capabilities return capabilities
} }
builtinParser := parsers.ParserForName(m.Config.Parser)
// Check for tools capability // Check for tools capability
if slices.Contains(m.Template.Vars(), "tools") { if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
capabilities = append(capabilities, model.CapabilityTools) capabilities = append(capabilities, model.CapabilityTools)
} }
@ -109,10 +119,16 @@ func (m *Model) Capabilities() []model.Capability {
capabilities = append(capabilities, model.CapabilityVision) capabilities = append(capabilities, model.CapabilityVision)
} }
// Skip the thinking check if it's already set
if slices.Contains(capabilities, "thinking") {
return capabilities
}
// Check for thinking capability // Check for thinking capability
openingTag, closingTag := thinking.InferTags(m.Template.Template) openingTag, closingTag := thinking.InferTags(m.Template.Template)
hasTags := openingTag != "" && closingTag != "" hasTags := openingTag != "" && closingTag != ""
if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) { isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily)
if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) {
capabilities = append(capabilities, model.CapabilityThinking) capabilities = append(capabilities, model.CapabilityThinking)
} }
@ -198,6 +214,20 @@ func (m *Model) String() string {
}) })
} }
if m.Config.Renderer != "" {
modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "renderer",
Args: m.Config.Renderer,
})
}
if m.Config.Parser != "" {
modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "parser",
Args: m.Config.Parser,
})
}
for k, v := range m.Options { for k, v := range m.Options {
switch v := v.(type) { switch v := v.(type) {
case []any: case []any:
@ -236,8 +266,19 @@ type ConfigV2 struct {
ModelFormat string `json:"model_format"` ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"` ModelFamily string `json:"model_family"`
ModelFamilies []string `json:"model_families"` ModelFamilies []string `json:"model_families"`
ModelType string `json:"model_type"` ModelType string `json:"model_type"` // shown as Parameter Size
FileType string `json:"file_type"` FileType string `json:"file_type"` // shown as Quantization Level
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
// used for remotes
Capabilities []string `json:"capabilities,omitempty"`
ContextLen int `json:"context_length,omitempty"`
EmbedLen int `json:"embedding_length,omitempty"`
BaseName string `json:"base_name,omitempty"`
// required by spec // required by spec
Architecture string `json:"architecture"` Architecture string `json:"architecture"`

View File

@ -25,10 +25,7 @@ func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
// n^2 backoff timer is a little smoother than the // n^2 backoff timer is a little smoother than the
// common choice of 2^n. // common choice of 2^n.
d := time.Duration(n*n) * 10 * time.Millisecond d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
if d > maxBackoff {
d = maxBackoff
}
// Randomize the delay between 0.5-1.5 x msec, in order // Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems. // to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5)) d = time.Duration(float64(d) * (rand.Float64() + 0.5))

View File

@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/renderers"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
@ -41,18 +42,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
} }
} }
thinkVal := false p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
thinkLevel := "" if err != nil {
if think != nil {
thinkVal = think.Bool()
thinkLevel = think.String()
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
return "", nil, err return "", nil, err
} }
s, err := tokenize(ctx, b.String()) s, err := tokenize(ctx, p)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -101,6 +96,23 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
} }
// truncate any messages that do not fit into the context window // truncate any messages that do not fit into the context window
p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
if err != nil {
return "", nil, err
}
return p, images, nil
}
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
if m.Config.Renderer != "" {
rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
if err != nil {
return "", err
}
return rendered, nil
}
var b bytes.Buffer var b bytes.Buffer
thinkVal := false thinkVal := false
thinkLevel := "" thinkLevel := ""
@ -108,9 +120,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
thinkVal = think.Bool() thinkVal = think.Bool()
thinkLevel = think.String() thinkLevel = think.String()
} }
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
return "", nil, err return "", err
} }
return b.String(), nil
return b.String(), images, nil
} }

View File

@ -15,6 +15,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
"net/url"
"os" "os"
"os/signal" "os/signal"
"slices" "slices"
@ -28,6 +29,7 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/discover" "github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
@ -35,6 +37,7 @@ import (
"github.com/ollama/ollama/harmony" "github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/server/internal/registry"
@ -46,6 +49,18 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
return true
}
}
return false
}
func experimentEnabled(name string) bool { func experimentEnabled(name string) bool {
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
} }
@ -176,6 +191,84 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
origModel := req.Model
remoteURL, err := url.Parse(m.Config.RemoteHost)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
return
}
req.Model = m.Config.RemoteModel
if req.Template == "" && m.Template.String() != "" {
req.Template = m.Template.String()
}
if req.Options == nil {
req.Options = map[string]any{}
}
for k, v := range m.Options {
if _, ok := req.Options[k]; !ok {
req.Options[k] = v
}
}
// update the system prompt from the model if one isn't already specified
if req.System == "" && m.System != "" {
req.System = m.System
}
if len(m.Messages) > 0 {
slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead")
}
fn := func(resp api.GenerateResponse) error {
resp.Model = origModel
resp.RemoteModel = m.Config.RemoteModel
resp.RemoteHost = m.Config.RemoteHost
data, err := json.Marshal(resp)
if err != nil {
return err
}
if _, err = c.Writer.Write(append(data, '\n')); err != nil {
return err
}
c.Writer.Flush()
return nil
}
client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Generate(c, &req, fn)
if err != nil {
var sErr api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
pk, pkErr := auth.GetPublicKey()
if pkErr != nil {
slog.Error("couldn't get public key", "error", pkErr)
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
return
}
c.JSON(http.StatusUnauthorized, gin.H{"public_key": pk})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
return
}
// expire the runner // expire the runner
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m) s.sched.expireRunner(m)
@ -195,11 +288,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw useHarmony := shouldUseHarmony(m) && !req.Raw
var functionNameMap *harmony.FunctionNameMap var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if useHarmony { if useHarmony {
functionNameMap = harmony.NewFunctionNameMap() harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
} }
// Validate Think value: string values currently only allowed for gptoss models // Validate Think value: string values currently only allowed for gptoss models
@ -315,10 +410,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// If debug mode is enabled, return the rendered template instead of calling the model // If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly { if req.DebugRenderOnly {
c.JSON(http.StatusOK, api.DebugTemplateResponse{ c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
DebugInfo: api.DebugInfo{ DebugInfo: &api.DebugInfo{
RenderedTemplate: prompt, RenderedTemplate: prompt,
ImageCount: len(images), ImageCount: len(images),
}, },
@ -334,6 +429,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
OpeningTag: openingTag, OpeningTag: openingTag,
ClosingTag: closingTag, ClosingTag: closingTag,
} }
if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
thinkingState.AddContent(openingTag)
}
} }
} }
@ -347,15 +445,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
UseHarmony: useHarmony,
}, func(cr llm.CompletionResponse) { }, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{ res := api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Response: cr.Content, Response: cr.Content,
Done: cr.Done, Done: cr.Done,
Thinking: cr.Thinking,
ToolCalls: cr.ToolCalls,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount, PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration, PromptEvalDuration: cr.PromptEvalDuration,
@ -364,22 +459,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}, },
} }
if res.Done {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if useHarmony { if useHarmony {
for i, tool := range res.ToolCalls { content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name) res.Response = content
} res.Thinking = thinking
if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done { harmonyToolParser.Add(toolContent)
ch <- res } else if thinkingState != nil {
}
return
}
if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content) thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking res.Thinking = thinking
res.Response = content res.Response = content
@ -390,6 +475,30 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
if cr.Done { if cr.Done {
if useHarmony {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw { if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String()) tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil { if err != nil {
@ -463,7 +572,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
truncate := true truncate := true
if req.Truncate != nil && !*req.Truncate { if req.Truncate != nil && !*req.Truncate {
truncate = false truncate = false
} }
@ -530,7 +638,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen--
}
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
ctxLen--
}
tokens = tokens[:ctxLen] tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens) s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -551,7 +668,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
if err != nil { if err != nil {
return err return err
} }
embeddings[i] = normalize(embedding) // TODO: this first normalization should be done by the model
embedding = normalize(embedding)
if req.Dimensions > 0 && req.Dimensions < len(embedding) {
embedding = normalize(embedding[:req.Dimensions])
}
embeddings[i] = embedding
return nil return nil
}) })
} }
@ -577,11 +699,7 @@ func normalize(vec []float32) []float32 {
sum += v * v sum += v * v
} }
norm := float32(0.0) norm := float32(1.0 / max(math.Sqrt(float64(sum)), 1e-12))
if sum > 0 {
norm = float32(1.0 / math.Sqrt(float64(sum)))
}
for i := range vec { for i := range vec {
vec[i] *= norm vec[i] *= norm
} }
@ -896,6 +1014,28 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
ModifiedAt: manifest.fi.ModTime(), ModifiedAt: manifest.fi.ModTime(),
} }
if m.Config.RemoteHost != "" {
resp.RemoteHost = m.Config.RemoteHost
resp.RemoteModel = m.Config.RemoteModel
if m.Config.ModelFamily != "" {
resp.ModelInfo = make(map[string]any)
resp.ModelInfo["general.architecture"] = m.Config.ModelFamily
if m.Config.BaseName != "" {
resp.ModelInfo["general.basename"] = m.Config.BaseName
}
if m.Config.ContextLen > 0 {
resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen
}
if m.Config.EmbedLen > 0 {
resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen
}
}
}
var params []string var params []string
cs := 30 cs := 30
for k, v := range m.Options { for k, v := range m.Options {
@ -926,6 +1066,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
fmt.Fprint(&sb, m.String()) fmt.Fprint(&sb, m.String())
resp.Modelfile = sb.String() resp.Modelfile = sb.String()
// skip loading tensor information if this is a remote model
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
return resp, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose) kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1004,6 +1149,8 @@ func (s *Server) ListHandler(c *gin.Context) {
models = append(models, api.ListModelResponse{ models = append(models, api.ListModelResponse{
Model: n.DisplayShortest(), Model: n.DisplayShortest(),
Name: n.DisplayShortest(), Name: n.DisplayShortest(),
RemoteModel: cf.RemoteModel,
RemoteHost: cf.RemoteHost,
Size: m.Size(), Size: m.Size(),
Digest: m.digest, Digest: m.digest,
ModifiedAt: m.fi.ModTime(), ModifiedAt: m.fi.ModTime(),
@ -1266,6 +1413,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/show", s.ShowHandler) r.POST("/api/show", s.ShowHandler)
r.DELETE("/api/delete", s.DeleteHandler) r.DELETE("/api/delete", s.DeleteHandler)
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
r.POST("/api/me", s.WhoamiHandler)
// Create // Create
r.POST("/api/create", s.CreateHandler) r.POST("/api/create", s.CreateHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler)
@ -1462,6 +1612,49 @@ func streamResponse(c *gin.Context, ch chan any) {
}) })
} }
func (s *Server) WhoamiHandler(c *gin.Context) {
// todo allow other hosts
u, err := url.Parse("https://ollama.com")
if err != nil {
slog.Error(err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
return
}
client := api.NewClient(u, http.DefaultClient)
user, err := client.Whoami(c)
if err != nil {
slog.Error(err.Error())
}
c.JSON(http.StatusOK, user)
}
func (s *Server) SignoutHandler(c *gin.Context) {
encodedKey := c.Param("encodedKey")
// todo allow other hosts
u, err := url.Parse("https://ollama.com")
if err != nil {
slog.Error(err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
return
}
client := api.NewClient(u, http.DefaultClient)
err = client.Signout(c, encodedKey)
if err != nil {
slog.Error(err.Error())
if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") {
c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
return
}
c.JSON(http.StatusOK, nil)
}
func (s *Server) PsHandler(c *gin.Context) { func (s *Server) PsHandler(c *gin.Context) {
models := []api.ProcessModelResponse{} models := []api.ProcessModelResponse{}
@ -1518,9 +1711,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
// expire the runner name := model.ParseName(req.Model)
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { if !name.IsValid() {
model, err := GetModel(req.Model) c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
m, err := GetModel(req.Model)
if err != nil { if err != nil {
switch { switch {
case os.IsNotExist(err): case os.IsNotExist(err):
@ -1532,7 +1735,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
return return
} }
s.sched.expireRunner(model)
// expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
s.sched.expireRunner(m)
c.JSON(http.StatusOK, api.ChatResponse{ c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model, Model: req.Model,
@ -1544,6 +1750,66 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
origModel := req.Model
remoteURL, err := url.Parse(m.Config.RemoteHost)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
return
}
req.Model = m.Config.RemoteModel
if req.Options == nil {
req.Options = map[string]any{}
}
msgs := append(m.Messages, req.Messages...)
if req.Messages[0].Role != "system" && m.System != "" {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
msgs = filterThinkTags(msgs, m)
req.Messages = msgs
for k, v := range m.Options {
if _, ok := req.Options[k]; !ok {
req.Options[k] = v
}
}
fn := func(resp api.ChatResponse) error {
resp.Model = origModel
resp.RemoteModel = m.Config.RemoteModel
resp.RemoteHost = m.Config.RemoteHost
data, err := json.Marshal(resp)
if err != nil {
return err
}
if _, err = c.Writer.Write(append(data, '\n')); err != nil {
return err
}
c.Writer.Flush()
return nil
}
client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Chat(c, &req, fn)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
return
}
caps := []model.Capability{model.CapabilityCompletion} caps := []model.Capability{model.CapabilityCompletion}
if len(req.Tools) > 0 { if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools) caps = append(caps, model.CapabilityTools)
@ -1552,17 +1818,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
caps = append(caps, model.CapabilityThinking) caps = append(caps, model.CapabilityThinking)
} }
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@ -1591,21 +1846,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
msgs = filterThinkTags(msgs, m) msgs = filterThinkTags(msgs, m)
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) var builtinParser parsers.Parser
if m.Config.Parser != "" {
builtinParser = parsers.ParserForName(m.Config.Parser)
}
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony"
processedTools := req.Tools processedTools := req.Tools
var functionNameMap *harmony.FunctionNameMap
var prefillString string
// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
if useHarmony { if useHarmony {
prefillString = harmony.Prefill(msgs[len(msgs)-1]) harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
functionNameMap = harmony.NewFunctionNameMap() var lastMessage *api.Message
if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1]
}
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
// make a copy of tools to pass to the chat prompt. Function names may be // make a copy of tools to pass to the chat prompt. Function names may be
// renamed to be valid Harmony function names. // renamed to be valid Harmony function names.
processedTools = make([]api.Tool, len(req.Tools)) processedTools = make([]api.Tool, len(req.Tools))
copy(processedTools, req.Tools) copy(processedTools, req.Tools)
for i, tool := range processedTools { for i, tool := range processedTools {
processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name) processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
} }
} }
@ -1618,10 +1884,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
// If debug mode is enabled, return the rendered template instead of calling the model // If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly { if req.DebugRenderOnly {
c.JSON(http.StatusOK, api.DebugTemplateResponse{ c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
DebugInfo: api.DebugInfo{ DebugInfo: &api.DebugInfo{
RenderedTemplate: prompt, RenderedTemplate: prompt,
ImageCount: len(images), ImageCount: len(images),
}, },
@ -1662,13 +1928,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
UseHarmony: useHarmony,
PrefillString: prefillString,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
res := api.ChatResponse{ res := api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls}, Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done, Done: r.Done,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
@ -1683,14 +1947,54 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} }
// TODO(drifkin): fold this as much as possibleinto the generic m.Config.Parser logic
if useHarmony { if useHarmony {
for i, tool := range res.Message.ToolCalls { content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name) res.Message.Content = content
res.Message.Thinking = thinking
harmonyToolParser.Add(toolContent)
if r.Done {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
} }
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
}
}
// only send messages with meaningful content (empty messages confuse clients) // only send messages with meaningful content (empty messages confuse clients)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done { if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res ch <- res
} }
return
} else if builtinParser != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Message.Content = content
res.Message.Thinking = thinking
res.Message.ToolCalls = toolCalls
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
}
return return
} }

View File

@ -11,6 +11,7 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"slices" "slices"
"strings" "strings"
"testing" "testing"
@ -20,6 +21,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/types/model"
) )
var stream bool = false var stream bool = false
@ -615,6 +617,78 @@ func TestCreateTemplateSystem(t *testing.T) {
}) })
} }
func TestCreateAndShowRemoteModel(t *testing.T) {
gin.SetMode(gin.TestMode)
var s Server
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
From: "bob",
RemoteHost: "https://ollama.com",
Info: map[string]any{
"capabilities": []string{"completion", "tools", "thinking"},
"model_family": "gptoss",
"context_length": 131072,
"embedding_length": 2880,
"quantization_level": "MXFP4",
"parameter_size": "20.9B",
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("exected status code 200, actual %d", w.Code)
}
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test"})
if w.Code != http.StatusOK {
t.Fatalf("exected status code 200, actual %d", w.Code)
}
var resp api.ShowResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
expectedDetails := api.ModelDetails{
ParentModel: "",
Format: "",
Family: "gptoss",
Families: []string{"gptoss"},
ParameterSize: "20.9B",
QuantizationLevel: "MXFP4",
}
if !reflect.DeepEqual(resp.Details, expectedDetails) {
t.Errorf("model details: expected %#v, actual %#v", expectedDetails, resp.Details)
}
expectedCaps := []model.Capability{
model.Capability("completion"),
model.Capability("tools"),
model.Capability("thinking"),
}
if !slices.Equal(resp.Capabilities, expectedCaps) {
t.Errorf("capabilities: expected %#v, actual %#v", expectedCaps, resp.Capabilities)
}
v, ok := resp.ModelInfo["gptoss.context_length"]
ctxlen := v.(float64)
if !ok || int(ctxlen) != 131072 {
t.Errorf("context len: expected %d, actual %d", 131072, int(ctxlen))
}
v, ok = resp.ModelInfo["gptoss.embedding_length"]
embedlen := v.(float64)
if !ok || int(embedlen) != 2880 {
t.Errorf("embed len: expected %d, actual %d", 2880, int(embedlen))
}
fmt.Printf("resp = %#v\n", resp)
}
func TestCreateLicenses(t *testing.T) { func TestCreateLicenses(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String()) t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
} }
var response api.DebugTemplateResponse var response api.GenerateResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err) t.Fatalf("failed to unmarshal response: %v", err)
} }
@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String()) t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
} }
var response api.DebugTemplateResponse var response api.ChatResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err) t.Fatalf("failed to unmarshal response: %v", err)
} }

View File

@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "content streams as it arrives", name: "content streams as it arrives",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Content: "Hello", Done: false}, input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
wantContent: "Hello", wantContent: "Hello",
}, },
{ {
@ -125,7 +126,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
wantContent: ", world", wantContent: ", world",
}, },
{ {
input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop}, input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "!", wantContent: "!",
}, },
}, },
@ -134,15 +135,20 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "thinking streams separately from content", name: "thinking streams separately from content",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Thinking: "Thinking...", Done: false}, input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
wantThinking: "Thinking...", wantThinking: "Thinking...",
}, },
{ {
input: llm.CompletionResponse{Content: "Answer", Done: false}, input: llm.CompletionResponse{Content: "<|end|>", Done: false},
wantContent: "Answer", // No output expected - just closes the analysis message and resets state to normal
}, },
{ {
input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop}, input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
wantContent: "Answer", // After message end, state is reset to normal
},
{
input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
// No output expected - just closes the assistant message
}, },
}, },
}, },
@ -150,16 +156,24 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "partial tags buffer until complete", name: "partial tags buffer until complete",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Thinking: "Deep ", Done: false}, input: llm.CompletionResponse{Content: "<|chan", Done: false},
// No output - partial tag
},
{
input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
// No output - still building tags
},
{
input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
wantThinking: "Deep ", wantThinking: "Deep ",
}, },
{ {
input: llm.CompletionResponse{Thinking: "thought", Done: false}, input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
wantThinking: "thought", wantThinking: "thought",
}, },
{ {
input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop}, input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Done", wantContent: "Done", // After message end, state is reset to normal
}, },
}, },
}, },
@ -167,7 +181,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "simple assistant after analysis", name: "simple assistant after analysis",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop}, input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Answer", wantContent: "Answer",
wantThinking: "Think", wantThinking: "Think",
}, },
@ -177,7 +191,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call parsed and returned correctly", name: "tool call parsed and returned correctly",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop}, input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "The weather is sunny", wantContent: "The weather is sunny",
wantToolCalls: []api.ToolCall{ wantToolCalls: []api.ToolCall{
{ {
@ -196,10 +210,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call with streaming JSON across chunks", name: "tool call with streaming JSON across chunks",
steps: []step{ steps: []step{
{ {
input: llm.CompletionResponse{Done: false}, input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
// No output yet - incomplete JSON
}, },
{ {
input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true}, input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
// Still no output - incomplete JSON
},
{
input: llm.CompletionResponse{Content: "2\"}", Done: true},
wantToolCalls: []api.ToolCall{ wantToolCalls: []api.ToolCall{
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
mockResponses := []llm.CompletionResponse{ mockResponses := []llm.CompletionResponse{
{Content: "First ", Done: false}, {Content: "<|message|>First ", Done: false},
{Content: "chunk ", Done: false}, {Content: "chunk ", Done: false},
{Content: "here", Done: true, DoneReason: llm.DoneReasonStop}, {Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
} }
mock := mockRunner{ mock := mockRunner{
@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks) t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
} }
} }
func TestChatHarmonyParserStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
type expectedChunk struct {
afterResponse int // Which mock response this chunk should appear after
content string // Expected content in this chunk
thinking string // Expected thinking in this chunk
}
testCases := []struct {
name string
mockResponses []llm.CompletionResponse
expectedChunks []expectedChunk
wantContent string
wantThinking string
}{
{
name: "simple message without thinking",
mockResponses: []llm.CompletionResponse{
{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
{Content: "how can I help?", Done: false},
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 1, content: "Hello, "},
{afterResponse: 2, content: "how can I help?"},
},
wantContent: "Hello, how can I help?",
},
{
name: "message with analysis channel for thinking",
mockResponses: []llm.CompletionResponse{
{Content: "<|channel|>analysis<|message|>", Done: false},
{Content: "Let me think ", Done: false},
{Content: "about this problem...", Done: false},
{Content: "<|end|>", Done: false},
{Content: "<|start|>assistant<|message|>", Done: false},
{Content: "The answer ", Done: false},
{Content: "is 42", Done: false},
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 2, thinking: "Let me think "},
{afterResponse: 3, thinking: "about this problem..."},
{afterResponse: 6, content: "The answer "},
{afterResponse: 7, content: "is 42"},
},
wantContent: "The answer is 42",
wantThinking: "Let me think about this problem...",
},
{
name: "streaming with partial tags across boundaries",
mockResponses: []llm.CompletionResponse{
{Content: "<|chan", Done: false},
{Content: "nel|>analy", Done: false},
{Content: "sis<|mess", Done: false},
{Content: "age|>Think", Done: false},
{Content: "ing deeply...<|end|>", Done: false},
{Content: "<|start|>assi", Done: false},
{Content: "stant<|message|>Result ", Done: false},
{Content: "computed<|e", Done: false},
{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 4, thinking: "Think"},
{afterResponse: 5, thinking: "ing deeply..."},
{afterResponse: 7, content: "Result "},
{afterResponse: 8, content: "computed"},
},
wantContent: "Result computed",
wantThinking: "Thinking deeply...",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Channel to synchronize mock responses with chunk verification
responsesSent := make(chan int, len(tc.mockResponses))
mock := mockRunner{
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
// Send mock responses one at a time, notifying when each is sent
for i, resp := range tc.mockResponses {
fn(resp)
responsesSent <- i + 1
}
close(responsesSent)
return nil
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a minimal model
_, digest := createHarmonyTestModel(t)
// Create model with passthrough template
stream := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "harmony-test",
Files: map[string]string{"file.gguf": digest},
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("failed to create model: %d", w.Code)
}
// Test chat endpoint with streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "harmony-test",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: &streamTrue,
Tools: getTestTools(),
})
if w.Code != http.StatusOK {
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
}
// Parse streaming response
var chunks []api.ChatResponse
var content, thinking strings.Builder
decoder := json.NewDecoder(w.Body)
for decoder.More() {
var chunk api.ChatResponse
if err := decoder.Decode(&chunk); err != nil {
t.Fatalf("failed to decode chunk: %v", err)
}
chunks = append(chunks, chunk)
// Accumulate content and thinking from each chunk
content.WriteString(chunk.Message.Content)
thinking.WriteString(chunk.Message.Thinking)
// Debug output
t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
}
// Verify we got streaming chunks
if len(chunks) == 0 {
t.Fatal("expected streaming chunks, got none")
}
gotContent := content.String()
gotThinking := thinking.String()
if gotContent != tc.wantContent {
t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
}
if gotThinking != tc.wantThinking {
t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
}
// Verify last chunk has done=true
lastChunk := chunks[len(chunks)-1]
if !lastChunk.Done {
t.Error("expected last chunk to have done=true")
}
})
}
}

View File

@ -126,7 +126,15 @@ func TestRoutes(t *testing.T) {
t.Fatalf("failed to create model: %v", err) t.Fatalf("failed to create model: %v", err)
} }
if err := createModel(r, modelName, baseLayers, fn); err != nil { config := &ConfigV2{
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
}
if err := createModel(r, modelName, baseLayers, config, fn); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@ -382,10 +382,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs // load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
// (if any). Returns whether the scheduler needs to evict a model to make this one fit. // (if any). Returns whether the scheduler needs to evict a model to make this one fit.
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool { func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool {
numParallel := int(envconfig.NumParallel()) numParallel := max(int(envconfig.NumParallel()), 1)
if numParallel < 1 {
numParallel = 1
}
// Embedding models should always be loaded with parallel=1 // Embedding models should always be loaded with parallel=1
if req.model.CheckCapabilities(model.CapabilityCompletion) != nil { if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {