From 8d6fffaead722d86622b17c059051796d2621858 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Wed, 10 Sep 2025 11:24:42 -0700 Subject: [PATCH 01/32] runner: simplify parser entrypoints in runner (#12233) --- harmony/harmonyparser.go | 18 +++-- harmony/harmonyparser_test.go | 28 ++++---- llm/server.go | 7 +- parser/token_parser.go | 126 ++++++++++++++++++++++++++++++++++ runner/ollamarunner/runner.go | 54 +++------------ server/routes.go | 17 ++++- 6 files changed, 173 insertions(+), 77 deletions(-) create mode 100644 parser/token_parser.go diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index addce4c94..3ec2c21f1 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -289,6 +289,7 @@ type HarmonyMessageHandler struct { state harmonyMessageState HarmonyParser *HarmonyParser FunctionNameMap *FunctionNameMap + ToolParser *HarmonyToolCallAccumulator } // NewHarmonyMessageHandler creates a new message handler @@ -301,12 +302,16 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler { HeaderEndTag: "<|message|>", }, FunctionNameMap: NewFunctionNameMap(), + ToolParser: &HarmonyToolCallAccumulator{ + state: harmonyToolCallState_Normal, + currentToolName: nil, + }, } } // AddContent processes the content and returns the content, thinking, and tool content. // content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser -func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) { +func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) { contentSb := strings.Builder{} thinkingSb := strings.Builder{} toolContentSb := strings.Builder{} @@ -323,14 +328,14 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo // event.Header.Recipient is the tool name, something like // "browser.search" for a built-in, or "functions.calc" for a // custom one - toolParser.SetToolName(event.Header.Recipient) + h.ToolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Thinking } case "commentary": if event.Header.Recipient != "" { h.state = harmonyMessageState_ToolCalling - toolParser.SetToolName(event.Header.Recipient) + h.ToolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Normal } @@ -353,13 +358,6 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo return contentSb.String(), thinkingSb.String(), toolContentSb.String() } -func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator { - return &HarmonyToolCallAccumulator{ - state: harmonyToolCallState_Normal, - currentToolName: nil, - } -} - type harmonyToolCallState int const ( diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go index dcf1af4e8..82bf5b2de 100644 --- a/harmony/harmonyparser_test.go +++ b/harmony/harmonyparser_test.go @@ -541,7 +541,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("thinking_then_content_streams", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser type step struct { in string wantContent string @@ -554,7 +554,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { {in: "<|end|>", wantContent: ""}, } for i, s := range steps { - content, thinking, tool := handler.AddContent(s.in, tp) + content, thinking, tool := handler.AddContent(s.in) if tool != "" { 
tp.Add(tool) } @@ -567,7 +567,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("content_streams_as_it_arrives", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|start|>assistant<|message|>Hello", ", world", @@ -575,7 +575,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var got []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if tool != "" { tp.Add(tool) } @@ -595,7 +595,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("thinking_streams_separately_from_content", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|channel|>analysis<|message|>Thinking...", "<|end|>", @@ -604,7 +604,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var got []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if tool != "" { tp.Add(tool) } @@ -624,7 +624,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("partial_tags_buffer_until_complete", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|chan", "nel|>analysis<|mess", @@ -637,7 +637,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { var thinkingPieces []string var contentPieces []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if tool != "" { tp.Add(tool) } @@ -659,7 +659,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("simple_assistant_after_analysis", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|channel|>analysis<|message|>Think", "<|end|>", @@ -668,7 +668,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var contentSb, thinkingSb strings.Builder for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if tool != "" { tp.Add(tool) } @@ -686,12 +686,12 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>", } for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if content != "" || thinking != "" { continue } @@ -711,14 +711,14 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("tool_call_across_chunks", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|channel|>commentary 
to=functions.calculate<|message|>{\"expression\":\"2+", "2\"}", "<|end|>", } for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if content != "" || thinking != "" { continue } diff --git a/llm/server.go b/llm/server.go index 7bc2ca13d..4740a1fd4 100644 --- a/llm/server.go +++ b/llm/server.go @@ -35,6 +35,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/parser" ) type filteredEnv []string @@ -1350,7 +1351,7 @@ type CompletionRequest struct { Options *api.Options Grammar string // set before sending the request to the subprocess - UseHarmony bool + ParserType parser.TokenParserType PrefillString string } @@ -1364,8 +1365,6 @@ const ( DoneReasonLength // DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed DoneReasonConnectionClosed - // DoneReasonTokenRepeatLimit indicates the completion stopped due to a token repeat limit - DoneReasonTokenRepeatLimit ) func (d DoneReason) String() string { @@ -1374,8 +1373,6 @@ func (d DoneReason) String() string { return "length" case DoneReasonStop: return "stop" - case DoneReasonTokenRepeatLimit: - return "token_repeat_limit" default: return "" // closed } diff --git a/parser/token_parser.go b/parser/token_parser.go new file mode 100644 index 000000000..812458299 --- /dev/null +++ b/parser/token_parser.go @@ -0,0 +1,126 @@ +package parser + +import ( + "encoding/json" + "errors" + "strings" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/harmony" +) + +type TokenParserType int + +const ( + TokenParserTypeDefault TokenParserType = iota + TokenParserTypeHarmony +) + +type TokenParser struct { + messageHandler MessageHandler + parserEngine ParserInternals + toolParser ToolParser + lastToken string + tokenRepeat int + repeatLimit int +} + +const defaultTokenRepeatLimit = 30 + +type MessageHandler interface { + AddContent(token string) (content, thinking string, toolContent string) +} + +type ParserInternals interface { + AddImplicitStartOrPrefill(prefillString string) +} + +type ToolParser interface { + Add(token string) + Drain() (toolName *string, toolContent string) +} + +// Default implementation for the TokenParser interface as a no-op passthrough +type defaultMessageHandler struct{} + +func (defaultMessageHandler) AddContent(token string) (string, string, string) { + return token, "", "" +} + +type defaultEngine struct{} + +func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {} + +type defaultToolParser struct{} + +func (defaultToolParser) Add(token string) {} + +func (defaultToolParser) Drain() (*string, string) { return nil, "" } + +func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser { + switch parserType { + case TokenParserTypeHarmony: + harmonyMessageHandler := harmony.NewHarmonyMessageHandler() + harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString) + return TokenParser{ + messageHandler: harmonyMessageHandler, + parserEngine: harmonyMessageHandler.HarmonyParser, + toolParser: harmonyMessageHandler.ToolParser, + repeatLimit: defaultTokenRepeatLimit, + } + + default: + return TokenParser{ + messageHandler: defaultMessageHandler{}, + parserEngine: defaultEngine{}, + toolParser: defaultToolParser{}, + repeatLimit: 30, + } + } +} + +func (p *TokenParser) AddContent(token string) (string, string, error) { + if p.repeatLimitReached(token) { + return "", 
"", errors.New("token repeat limit reached") + } + content, thinking, toolContent := p.messageHandler.AddContent(token) + p.toolParser.Add(toolContent) + return content, thinking, nil +} + +// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached. +func (p *TokenParser) repeatLimitReached(token string) bool { + if p == nil { + return false + } + trimmed := strings.TrimSpace(token) + if trimmed == p.lastToken { + p.tokenRepeat++ + } else { + p.tokenRepeat = 0 + } + p.lastToken = trimmed + + return p.tokenRepeat >= p.repeatLimit +} + +// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level +func (p *TokenParser) Drain() []api.ToolCall { + toolName, toolContent := p.toolParser.Drain() + if toolName != nil { + *toolName = strings.TrimPrefix(*toolName, "functions.") + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(toolContent), &args); err != nil { + return nil + } + return []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: *toolName, + Arguments: args, + }, + }, + } + } + return nil +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index a40643ef2..201d55a16 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -30,12 +30,12 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/harmony" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" + "github.com/ollama/ollama/parser" "github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/sample" @@ -782,13 +782,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator - if req.UseHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() - harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillString) - harmonyToolParser = harmonyMessageHandler.CreateToolParser() - } + tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString) if req.Options == nil { opts := api.DefaultOptions() @@ -872,9 +866,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { http.Error(w, "could not find an available sequence", http.StatusInternalServerError) return } - var lastToken string - tokenRepeat := 0 - const tokenRepeatLimit = 30 for { select { @@ -883,23 +874,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if strings.TrimSpace(content) == lastToken { - tokenRepeat++ - } - if tokenRepeat == tokenRepeatLimit { - http.Error(w, "token repeat limit reached", http.StatusInternalServerError) - seq.doneReason = llm.DoneReasonTokenRepeatLimit + var thinking string + var err error + content, thinking, err = tokenParser.AddContent(content) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) close(seq.quit) return } - lastToken = strings.TrimSpace(content) - - var thinking string - if harmonyMessageHandler != nil { - var toolContent string - content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser) - harmonyToolParser.Add(toolContent) - } if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, @@ -912,27 +894,7 @@ func (s *Server) completion(w http.ResponseWriter, r 
*http.Request) { flusher.Flush() } else { - var toolCalls []api.ToolCall - if harmonyMessageHandler != nil { - // these tools still need to be transformed to the original function name - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError) - close(seq.quit) - return - } - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: *toolName, - Arguments: args, - }, - }) - } - } - + toolCalls := tokenParser.Drain() if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ ToolCalls: toolCalls, Done: true, diff --git a/server/routes.go b/server/routes.go index 73ea5fea4..ac4df4a46 100644 --- a/server/routes.go +++ b/server/routes.go @@ -36,6 +36,7 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/openai" + "github.com/ollama/ollama/parser" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" @@ -196,6 +197,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { } useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw + var parserType parser.TokenParserType + if useHarmony { + parserType = parser.TokenParserTypeHarmony + } else { + parserType = parser.TokenParserTypeDefault + } var functionNameMap *harmony.FunctionNameMap if useHarmony { @@ -347,7 +354,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { Images: images, Format: req.Format, Options: opts, - UseHarmony: useHarmony, + ParserType: parserType, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ Model: req.Model, @@ -1592,6 +1599,12 @@ func (s *Server) ChatHandler(c *gin.Context) { msgs = filterThinkTags(msgs, m) useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) + var parserType parser.TokenParserType + if useHarmony { + parserType = parser.TokenParserTypeHarmony + } else { + parserType = parser.TokenParserTypeDefault + } processedTools := req.Tools var functionNameMap *harmony.FunctionNameMap @@ -1662,7 +1675,7 @@ func (s *Server) ChatHandler(c *gin.Context) { Images: images, Format: req.Format, Options: opts, - UseHarmony: useHarmony, + ParserType: parserType, PrefillString: prefillString, }, func(r llm.CompletionResponse) { res := api.ChatResponse{ From 17a023f34b9ec81afd39aca36f3dcd488c4d48ea Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 10 Sep 2025 12:05:18 -0700 Subject: [PATCH 02/32] Add v12 + v13 cuda support (#12000) * Add support for upcoming NVIDIA Jetsons The latest Jetsons with JetPack 7 are moving to an SBSA compatible model and will not require building a JetPack specific variant. * cuda: bring back dual versions This adds back dual CUDA versions for our releases, with v11 and v13 to cover a broad set of GPUs and driver versions. 
* win: break up native builds in build_windows.ps1 * v11 build working on windows and linux * switch to cuda v12.8 not JIT * Set CUDA compression to size * enhance manual install linux docs --- .github/workflows/release.yaml | 13 +++++- .github/workflows/test.yaml | 6 +-- CMakeLists.txt | 6 +-- CMakePresets.json | 26 ++++++++++++ Dockerfile | 30 ++++++++++++-- discover/cuda_common.go | 15 +++---- docs/linux.md | 3 +- scripts/build_windows.ps1 | 72 +++++++++++++++++++++++++++++++--- 8 files changed, 146 insertions(+), 25 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 40871e644..902fa9ccc 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -67,12 +67,21 @@ jobs: install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe cuda-version: '12.8' flags: '' + runner_dir: 'cuda_v12' + - os: windows + arch: amd64 + preset: 'CUDA 13' + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + cuda-version: '13.0' + flags: '' + runner_dir: 'cuda_v13' - os: windows arch: amd64 preset: 'ROCm 6' install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe rocm-version: '6.2' flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' + runner_dir: '' runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release env: @@ -138,7 +147,7 @@ jobs: run: | Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' - cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} + cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}" cmake --build --parallel --preset "${{ matrix.preset }}" cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8 env: @@ -232,7 +241,7 @@ jobs: case "$COMPONENT" in bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; - lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; + lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;; lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;; lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;; diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 4d8cf773c..a10ad37a9 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -46,7 +46,7 @@ jobs: include: - preset: CPU - preset: CUDA - container: nvidia/cuda:12.8.1-devel-ubuntu22.04 + container: nvidia/cuda:13.0.0-devel-ubuntu22.04 flags: '-DCMAKE_CUDA_ARCHITECTURES=87' - preset: ROCm container: rocm/dev-ubuntu-22.04:6.1.2 @@ -78,7 +78,7 @@ jobs: 
include: - preset: CPU - preset: CUDA - install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe flags: '-DCMAKE_CUDA_ARCHITECTURES=80' - preset: ROCm install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe @@ -102,7 +102,7 @@ jobs: $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" - Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait + Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_13.0", "nvcc_13.0", "cublas_13.0", "cublas_dev_13.0")) -NoNewWindow -Wait } $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path diff --git a/CMakeLists.txt b/CMakeLists.txt index d62c8f99f..7cce5e4b1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,7 +25,7 @@ set(GGML_LLAMAFILE ON) set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128) set(GGML_CUDA_GRAPHS ON) set(GGML_CUDA_FA ON) -set(GGML_CUDA_COMPRESSION_MODE default) +set(GGML_CUDA_COMPRESSION_MODE size) if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64") OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+")) @@ -38,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64") endif() set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama) -set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama) +set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR}) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR}) @@ -81,7 +81,7 @@ if(CMAKE_CUDA_COMPILER) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda) install(TARGETS ggml-cuda RUNTIME_DEPENDENCIES - DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR} + DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} PRE_INCLUDE_REGEXES cublas cublasLt cudart PRE_EXCLUDE_REGEXES ".*" RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA diff --git a/CMakePresets.json b/CMakePresets.json index ab2cfe9d6..51190c719 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -18,6 +18,14 @@ "name": "CUDA", "inherits": [ "Default" ] }, + { + "name": "CUDA 11", + "inherits": [ "CUDA" ], + "cacheVariables": { + "CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual", + "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2" + } + }, { "name": "CUDA 12", "inherits": [ "CUDA" ], @@ -26,6 +34,14 @@ "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2" } }, + { + "name": "CUDA 13", + "inherits": [ "CUDA" ], + "cacheVariables": { + "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual", + "CMAKE_CUDA_FLAGS": "-t 2" + } + }, { "name": "JetPack 5", "inherits": [ "CUDA" ], @@ -72,11 +88,21 @@ "configurePreset": "CUDA", "targets": [ "ggml-cuda" ] }, + { + "name": "CUDA 11", + "inherits": [ "CUDA" ], + "configurePreset": "CUDA 11" + }, { "name": "CUDA 12", "inherits": [ "CUDA" ], "configurePreset": "CUDA 12" }, + { + "name": "CUDA 13", + "inherits": [ "CUDA" 
], + "configurePreset": "CUDA 13" + }, { "name": "JetPack 5", "inherits": [ "CUDA" ], diff --git a/Dockerfile b/Dockerfile index 0dc3c1267..c84b52392 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,15 +39,35 @@ RUN --mount=type=cache,target=/root/.ccache \ && cmake --build --parallel --preset 'CPU' \ && cmake --install build --component CPU --strip --parallel 8 +FROM base AS cuda-11 +ARG CUDA11VERSION=11.8 +RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} +ENV PATH=/usr/local/cuda-11/bin:$PATH +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \ + && cmake --build --parallel --preset 'CUDA 11' \ + && cmake --install build --component CUDA --strip --parallel 8 + FROM base AS cuda-12 ARG CUDA12VERSION=12.8 RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} ENV PATH=/usr/local/cuda-12/bin:$PATH RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'CUDA 12' \ + cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\ && cmake --build --parallel --preset 'CUDA 12' \ && cmake --install build --component CUDA --strip --parallel 8 + +FROM base AS cuda-13 +ARG CUDA13VERSION=13.0 +RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} +ENV PATH=/usr/local/cuda-13/bin:$PATH +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \ + && cmake --build --parallel --preset 'CUDA 13' \ + && cmake --install build --component CUDA --strip --parallel 8 + + FROM base AS rocm-6 ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH RUN --mount=type=cache,target=/root/.ccache \ @@ -92,10 +112,14 @@ RUN --mount=type=cache,target=/root/.cache/go-build \ go build -trimpath -buildmode=pie -o /bin/ollama . FROM --platform=linux/amd64 scratch AS amd64 -COPY --from=cuda-12 dist/lib/ollama /lib/ollama +# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ +COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ +COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ FROM --platform=linux/arm64 scratch AS arm64 -COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa +# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ +COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ +COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5 COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6 diff --git a/discover/cuda_common.go b/discover/cuda_common.go index b539f6b32..ca008af63 100644 --- a/discover/cuda_common.go +++ b/discover/cuda_common.go @@ -43,14 +43,15 @@ func cudaVariant(gpuInfo CudaGPUInfo) string { } } } - return "sbsa" } - // driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers - if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) { - // The detected driver is older than Feb 2023 - slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor)) - return "v11" + if gpuInfo.DriverMajor < 13 { + // The detected driver is older than 580 (Aug 2025) + // Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance + if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) { + slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor)) + } + return "v12" } - return "v12" + return "v13" } diff --git 
a/docs/linux.md b/docs/linux.md index 9a156d1dc..ce5ed860b 100644 --- a/docs/linux.md +++ b/docs/linux.md @@ -11,12 +11,13 @@ curl -fsSL https://ollama.com/install.sh | sh ## Manual install > [!NOTE] -> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first. +> If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first. Download and extract the package: ```shell curl -LO https://ollama.com/download/ollama-linux-amd64.tgz +sudo rm -rf /usr/lib/ollama sudo tar -C /usr -xzf ollama-linux-amd64.tgz ``` diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 27f3eb9d4..37fe87961 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -78,7 +78,7 @@ function checkEnv() { } -function buildOllama() { +function buildCPU() { mkdir -Force -path "${script:DIST_DIR}\" if ($script:ARCH -ne "arm64") { Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}" @@ -90,20 +90,72 @@ function buildOllama() { if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build --component CPU --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } +} +function buildCUDA11() { + # CUDA v11 claims to be compatible with MSVC 2022, but the latest updates are no longer compatible + # 19.40 is the last compiler version that works, but recent udpates are 19.43 + # So this pins to MSVC 2019 for best compatibility + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { $hashEnv = @{} Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } - if ("$script:CUDA_DIRS".Contains("v12")) { - $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }} - $env:CUDAToolkit_ROOT=$hashEnv[$v12] - write-host "Building CUDA v12 backend libraries" - & cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR + if ("$script:CUDA_DIRS".Contains("v11")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }} + write-host "Building CUDA v11 backend libraries $cuda" + $env:CUDAToolkit_ROOT=$cuda + & cmake --fresh --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component "CUDA" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + } +} + +function buildCUDA12() { + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { + $hashEnv = @{} + Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } + if ("$script:CUDA_DIRS".Contains("v12.8")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12_8")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }} + write-host "Building CUDA v12 backend libraries $cuda" + $env:CUDAToolkit_ROOT=$cuda + & cmake --fresh --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build --component "CUDA" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } + 
} +} + +function buildCUDA13() { + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { + $hashEnv = @{} + Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } + if ("$script:CUDA_DIRS".Contains("v13")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V13")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }} + $env:CUDAToolkit_ROOT=$cuda + write-host "Building CUDA v13 backend libraries $cuda" + & cmake --fresh --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset "CUDA 13" --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component "CUDA" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + } +} + +function buildROCm() { + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { if ($env:HIP_PATH) { write-host "Building ROCm backend libraries" if (-Not (get-command -ErrorAction silent ninja)) { @@ -129,6 +181,10 @@ function buildOllama() { if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } } +} + +function buildOllama() { + mkdir -Force -path "${script:DIST_DIR}\" write-host "Building ollama CLI" & go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" . if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} @@ -236,6 +292,10 @@ function distZip() { checkEnv try { if ($($args.count) -eq 0) { + buildCPU + buildCUDA12 + buildCUDA13 + buildROCm buildOllama buildApp gatherDependencies From 5198956372fb2098622fafe143b7026c8ce6ef2d Mon Sep 17 00:00:00 2001 From: "CarbonatedWater.org" Date: Wed, 10 Sep 2025 16:37:10 -0700 Subject: [PATCH 03/32] docs: add ollama-co2 to community integrations (#12230) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 481a29aea..5962f5b28 100644 --- a/README.md +++ b/README.md @@ -414,6 +414,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.) - [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models) - [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. 
Also undetectable to screenshare) +- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads) ### Cloud From 71cb86af3e8b8006540550a8eeb9fed106b77eee Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 9 Sep 2025 10:37:28 -0700 Subject: [PATCH 04/32] llm: Remove unneeded warning with flash attention enabled If flash attention is enabled without KV cache quanitization, we will currently always get this warning: level=WARN source=server.go:226 msg="kv cache type not supported by model" type="" --- fs/ggml/ggml.go | 10 +++++++--- llm/memory.go | 2 +- llm/server.go | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 56ad420e5..57476a9a8 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -864,12 +864,16 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) { // SupportsKVCacheType checks if the requested cache type is supported func (f GGML) SupportsKVCacheType(cacheType string) bool { + if cacheType == "" || cacheType == "f16" { + return true + } + if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) { // gpt-oss uses attention with sinks which does not support quantized cache types - slog.Warn("model only supports non-quantized cache types ", "mode", arch) - return cacheType == "f16" + slog.Warn("model only supports non-quantized cache types", "model", arch) + return false } - return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType) + return slices.Contains([]string{"q8_0", "q4_0"}, cacheType) } // SupportsFlashAttention checks if the model supports flash attention diff --git a/llm/memory.go b/llm/memory.go index ce128eb58..7a87b28fe 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -202,7 +202,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin var kvct string if useFlashAttention { requested := strings.ToLower(envconfig.KvCacheType()) - if requested != "" && f.SupportsKVCacheType(requested) { + if f.SupportsKVCacheType(requested) { kvct = requested } } diff --git a/llm/server.go b/llm/server.go index 4740a1fd4..a22ae9722 100644 --- a/llm/server.go +++ b/llm/server.go @@ -221,7 +221,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a // Flash Attention also supports kv cache quantization // Enable if the requested and kv cache type is supported by the model - if kvct != "" && f.SupportsKVCacheType(kvct) { + if f.SupportsKVCacheType(kvct) { loadRequest.KvCacheType = kvct } else { slog.Warn("kv cache type not supported by model", "type", kvct) From 29ddfc2cab7f5a83a96c3133094f67b22e4f27d1 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 9 Sep 2025 10:48:34 -0700 Subject: [PATCH 05/32] ggml: Disable flash attention for gemma2 Our new engine implementation of gemma2 doesn't support flash attention, which means that it also doesn't support KV cache quantization. Currently, it is possible to turn these two on, which will result in a crash. 
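A minimal sketch of the kind of architecture gate this change introduces (simplified and with an assumed helper name; the real check is the SupportsFlashAttention method in fs/ggml/ggml.go shown in the diff below):

    package main

    import (
        "fmt"
        "slices"
    )

    // supportsFlashAttention is a simplified stand-in for the fs/ggml check:
    // architectures whose new-engine implementation lacks flash attention are
    // rejected up front, which also rules out quantized KV cache for them.
    func supportsFlashAttention(arch string) bool {
        return !slices.Contains([]string{"gemma2"}, arch)
    }

    func main() {
        fmt.Println(supportsFlashAttention("gemma2")) // false
        fmt.Println(supportsFlashAttention("llama"))  // true
    }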
--- fs/ggml/ggml.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 57476a9a8..6b582b499 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -883,6 +883,10 @@ func (f GGML) SupportsFlashAttention() bool { return false } + if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) { + return false + } + // Check head counts match and are non-zero headCountK := f.KV().EmbeddingHeadCountK() headCountV := f.KV().EmbeddingHeadCountV() From 8a7e2055d2196df23e86ffe813c1e9287e18068e Mon Sep 17 00:00:00 2001 From: fengyuchuanshen Date: Fri, 12 Sep 2025 00:57:31 +0800 Subject: [PATCH 06/32] cmd: use slices.Contains to simplify code (#12249) --- cmd/cmd.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 8fe068655..19f1e192f 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -56,10 +56,8 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) if err != nil { return } - for _, cap := range resp.Capabilities { - if cap == model.CapabilityThinking { - return - } + if slices.Contains(resp.Capabilities, model.CapabilityThinking) { + return } fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name) } From feb18cd710dec1e4754ea56124238a11eb3cb90a Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 11 Sep 2025 10:36:10 -0700 Subject: [PATCH 07/32] feat: add dimensions field to embed requests (#12242) * feat: add field to truncate embeddings * add openai embeddings for dimensions --- api/types.go | 4 ++++ docs/api.md | 1 + openai/openai.go | 7 ++++--- server/routes.go | 13 +++++++------ 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/api/types.go b/api/types.go index d3f6fc5a4..a7ddbc373 100644 --- a/api/types.go +++ b/api/types.go @@ -388,8 +388,12 @@ type EmbedRequest struct { // this request. KeepAlive *Duration `json:"keep_alive,omitempty"` + // Truncate truncates the input to fit the model's max sequence length. Truncate *bool `json:"truncate,omitempty"` + // Dimensions truncates the output embedding to the specified dimension. + Dimensions int `json:"dimensions,omitempty"` + // Options lists model-specific options. Options map[string]any `json:"options"` } diff --git a/docs/api.md b/docs/api.md index f11d59ed1..f47af63c6 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1708,6 +1708,7 @@ Advanced parameters: - `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. 
Defaults to `true` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) +- `dimensions`: number of dimensions for the embedding ### Examples diff --git a/openai/openai.go b/openai/openai.go index 9c7c41cb4..b6a8a95e2 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -76,8 +76,9 @@ type JsonSchema struct { } type EmbedRequest struct { - Input any `json:"input"` - Model string `json:"model"` + Input any `json:"input"` + Model string `json:"model"` + Dimensions int `json:"dimensions,omitempty"` } type StreamOptions struct { @@ -1005,7 +1006,7 @@ func EmbeddingsMiddleware() gin.HandlerFunc { } var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil { + if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) return } diff --git a/server/routes.go b/server/routes.go index ac4df4a46..8dd1b217a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -558,7 +558,12 @@ func (s *Server) EmbedHandler(c *gin.Context) { if err != nil { return err } - embeddings[i] = normalize(embedding) + // TODO: this first normalization should be done by the model + embedding = normalize(embedding) + if req.Dimensions > 0 && req.Dimensions < len(embedding) { + embedding = normalize(embedding[:req.Dimensions]) + } + embeddings[i] = embedding return nil }) } @@ -584,11 +589,7 @@ func normalize(vec []float32) []float32 { sum += v * v } - norm := float32(0.0) - if sum > 0 { - norm = float32(1.0 / math.Sqrt(float64(sum))) - } - + norm := float32(1.0 / max(math.Sqrt(float64(sum)), 1e-12)) for i := range vec { vec[i] *= norm } From eb10390de96ad6f5c21bc9e61f6cd222405f627a Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 11 Sep 2025 10:30:18 -0700 Subject: [PATCH 08/32] llm: Enable new memory estimates by default New memory estimates (see #11090 for more information) are now enabled automatically for all models running on the Ollama engine, improving both stability and performance through more accurate sizing and allocation. Models running on the llama engine will continue to use the original style of memory estimation. 
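A rough sketch of the selection this leaves behind (assumed shape, not the exact llm/server.go code): the estimate style now follows the engine choice alone, with no OLLAMA_NEW_ESTIMATES opt-in.

    package main

    import "fmt"

    // pickEstimator is a simplified stand-in: models handled by the new
    // Ollama engine (a text processor could be created for them) get the new
    // memory estimates; models on the llama engine keep the original logic.
    func pickEstimator(hasTextProcessor bool) string {
        if hasTextProcessor {
            return "ollama engine: new memory estimates"
        }
        return "llama engine: original memory estimates"
    }

    func main() {
        fmt.Println(pickEstimator(true))
        fmt.Println(pickEstimator(false))
    }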
--- envconfig/config.go | 3 --- llm/server.go | 7 +------ 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index 868813ae8..7fc018870 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -185,8 +185,6 @@ var ( ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096) // Auth enables authentication between the Ollama client and server UseAuth = Bool("OLLAMA_AUTH") - // Enable the new memory estimation logic - NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES") ) func String(s string) func() string { @@ -272,7 +270,6 @@ func AsMap() map[string]EnvVar { "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, - "OLLAMA_NEW_ESTIMATES": {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"}, // Informational "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, diff --git a/llm/server.go b/llm/server.go index a22ae9722..5caf19875 100644 --- a/llm/server.go +++ b/llm/server.go @@ -162,11 +162,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a } } - newEstimates := textProcessor != nil && envconfig.NewMemoryEstimates() - if newEstimates { - slog.Info("enabling new memory estimates") - } - // Verify the requested context size is <= the model training size trainCtx := f.KV().ContextLength() if opts.NumCtx > int(trainCtx) && trainCtx > 0 { @@ -434,7 +429,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a } }() - if newEstimates { + if textProcessor != nil { return &ollamaServer{llmServer: s}, nil } else { return &llamaServer{llmServer: s, ggml: f}, nil From aba157531521192a04d09811fac3cda20e1a8340 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 10 Sep 2025 11:03:06 -0700 Subject: [PATCH 09/32] llm: Don't try to load split vision models in the Ollama engine If a model with a split vision projector is loaded in the Ollama engine, the projector will be ignored and the model will hallucinate a response. Instead, fallback and try to load the model in the llama engine. 
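A simplified sketch of the fallback path described above (assumed function name; the actual check is in llm/server.go below): a separate projector file now produces an error before the new engine is used, so the loader falls back instead of silently ignoring the projector.

    package main

    import (
        "errors"
        "fmt"
    )

    // chooseEngine is a simplified stand-in for the loader decision: split
    // vision models (weights plus a separate projector file) are sent to the
    // llama engine rather than being loaded without their projector.
    func chooseEngine(newEngineRequested bool, projectors []string) string {
        if newEngineRequested {
            var err error
            if len(projectors) > 0 {
                err = errors.New("split vision models aren't supported")
            }
            if err == nil {
                return "ollama engine"
            }
            // fall through to the compatibility path on error
        }
        return "llama engine"
    }

    func main() {
        fmt.Println(chooseEngine(true, nil))                     // ollama engine
        fmt.Println(chooseEngine(true, []string{"mmproj.gguf"})) // llama engine
    }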
--- llm/server.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/llm/server.go b/llm/server.go index 5caf19875..9100b6978 100644 --- a/llm/server.go +++ b/llm/server.go @@ -149,7 +149,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a var textProcessor model.TextProcessor var err error if envconfig.NewEngine() || f.KV().OllamaEngineRequired() { - textProcessor, err = model.NewTextProcessor(modelPath) + if len(projectors) == 0 { + textProcessor, err = model.NewTextProcessor(modelPath) + } else { + err = errors.New("split vision models aren't supported") + } if err != nil { // To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err) From 61fb912ca46fe902180892316f6cc34adda07b67 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Thu, 11 Sep 2025 12:25:26 -0700 Subject: [PATCH 10/32] CI: fix windows cuda build (#12246) * ci: adjust cuda component list v13 has a different breakdown of the components required to build ollama * review comments --- .github/workflows/release.yaml | 15 ++++++++++++++- .github/workflows/test.yaml | 12 +++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 902fa9ccc..fc3cde9c9 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -65,6 +65,11 @@ jobs: arch: amd64 preset: 'CUDA 12' install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' cuda-version: '12.8' flags: '' runner_dir: 'cuda_v12' @@ -72,6 +77,14 @@ jobs: arch: amd64 preset: 'CUDA 13' install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' + - '"crt"' + - '"nvvm"' + - '"nvptxcompiler"' cuda-version: '13.0' flags: '' runner_dir: 'cuda_v13' @@ -105,7 +118,7 @@ jobs: $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" - $subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | Foreach-Object {"${_}_${{ matrix.cuda-version }}"} + $subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"} Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait } diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a10ad37a9..e470540a2 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -80,6 +80,15 @@ jobs: - preset: CUDA install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe flags: '-DCMAKE_CUDA_ARCHITECTURES=80' + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' + - '"crt"' + - '"nvvm"' + - '"nvptxcompiler"' + cuda-version: '13.0' - preset: ROCm install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes 
-Wno-deprecated-pragma"' @@ -102,7 +111,8 @@ jobs: $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" - Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_13.0", "nvcc_13.0", "cublas_13.0", "cublas_dev_13.0")) -NoNewWindow -Wait + $subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"} + Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait } $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path From 26214125e86ac1d4512dff68c983137589cfddbf Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 11 Sep 2025 13:48:51 -0700 Subject: [PATCH 11/32] ollamarunner: Suppress stack trace during memory allocation Allocation failures can be a normal part of new memory estimates, so we shouldn't print a stack trace in this case. --- runner/ollamarunner/runner.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 201d55a16..676e5186f 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -18,7 +18,6 @@ import ( "reflect" "regexp" "runtime" - "runtime/debug" "strconv" "strings" "sync" @@ -1101,9 +1100,13 @@ func (s *Server) allocModel( // Convert memory allocation panics to errors defer func() { if r := recover(); r != nil { - debug.PrintStack() if err, ok := r.(error); ok { - panicErr = err + var noMem ml.ErrNoMem + if errors.As(err, &noMem) { + panicErr = noMem + } else { + panic(r) + } } else { panic(r) } From e4ce68311a64310ece5534ae3a4820b20ea3d42f Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 12 Sep 2025 07:59:14 -0700 Subject: [PATCH 12/32] cuda: remove compression for better compatibility (#12259) This retains compatibility with driver 531 and up at the trade-off of space. --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7cce5e4b1..198fcdeb9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,7 +25,7 @@ set(GGML_LLAMAFILE ON) set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128) set(GGML_CUDA_GRAPHS ON) set(GGML_CUDA_FA ON) -set(GGML_CUDA_COMPRESSION_MODE size) +set(GGML_CUDA_COMPRESSION_MODE default) if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64") OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+")) From 44a679287366daf04c56c61fe0ab135de7de94c2 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 12 Sep 2025 13:59:34 -0700 Subject: [PATCH 13/32] tests: tighten up a few flaky tests (#12271) Sometimes the context test results are pure emoji's Thanksgiving has too much variability, so swap for a more straight forward prompt. 
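For context, a hypothetical sketch of the any-keyword assertion style these integration tests rely on (the helper name is made up, not the actual utils_test.go code): a response passes if it mentions at least one expected keyword, which is why prompts with highly variable or emoji-only answers flake.

    package main

    import (
        "fmt"
        "strings"
    )

    // matchesAny reports whether the response mentions at least one expected
    // keyword, case-insensitively.
    func matchesAny(response string, keywords []string) bool {
        r := strings.ToLower(response)
        for _, k := range keywords {
            if strings.Contains(r, k) {
                return true
            }
        }
        return false
    }

    func main() {
        resp := "Rainbows form when sunlight is refracted and reflected by water droplets."
        fmt.Println(matchesAny(resp, []string{"water", "droplet", "refracted", "spectrum"})) // true
        fmt.Println(matchesAny("🌈✨", []string{"water", "droplet"}))                        // false
    }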
--- integration/context_test.go | 2 +- integration/utils_test.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integration/context_test.go b/integration/context_test.go index ca6f16087..15c157858 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) { // Set up the test data req := api.GenerateRequest{ Model: smol, - Prompt: "Write me a story with a ton of emojis?", + Prompt: "Write me a story in english with a lot of emojis", Stream: &stream, Options: map[string]any{ "temperature": 0, diff --git a/integration/utils_test.go b/integration/utils_test.go index ec74b2e3d..7901fed3f 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -561,7 +561,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { KeepAlive: &api.Duration{Duration: 10 * time.Second}, }, { Model: smol, - Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply", + Prompt: "how do rainbows form? Be brief but factual in your reply", Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, }, { @@ -579,9 +579,9 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { [][]string{ {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"}, {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"}, - {"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"}, + {"water", "droplet", "refracted", "reflect", "color", "spectrum"}, {"fourth", "july", "declaration", "independence"}, - {"nitrogen", "oxygen", "carbon", "dioxide"}, + {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"}, } } From 053092185eda0ac5272400ac4d4be135944800f7 Mon Sep 17 00:00:00 2001 From: tc-mb <157115220+tc-mb@users.noreply.github.com> Date: Sat, 13 Sep 2025 07:25:12 +0800 Subject: [PATCH 14/32] Fix image cannot be seen with slice image on llama engine MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ollama's recent engine update, llama.cpp, caused all models requiring a slice schema to not display images. As a result, the value of numTokens isn't always the length of the sliced ​​image embed, but rather the end length of the schema. This causes the image embed to not be correctly included during all slice processing. 
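A simplified illustration of the sizing bug and the per-chunk fix described above (plain Go slices standing in for the mtmd C API): sizing the output by the last chunk's token count drops the earlier image slices, while appending chunk by chunk keeps every embedding row.

    package main

    import "fmt"

    // flattenChunks appends the embedding rows of every chunk, instead of
    // allocating only len(last chunk) rows as the old code effectively did.
    func flattenChunks(chunks [][][]float32) [][]float32 {
        var embed [][]float32
        for _, chunk := range chunks {
            embed = append(embed, chunk...)
        }
        return embed
    }

    func main() {
        chunks := [][][]float32{
            make([][]float32, 64), // image slice 1
            make([][]float32, 64), // image slice 2
            make([][]float32, 8),  // trailing tokens from the slice schema
        }
        fmt.Println(len(flattenChunks(chunks))) // 136 rows, not just 8
    }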
--- llama/llama.go | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/llama/llama.go b/llama/llama.go index ac2c112c2..88672a033 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, } nChunks := C.mtmd_input_chunks_size(ic) numEmbed := llamaContext.Model().NEmbd() - lastChunkSize := 0 + embed := make([][]float32, 0) for i := range int(nChunks) { chunk := C.mtmd_input_chunks_get(ic, C.size_t(i)) numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk)) - lastChunkSize = numTokens + slog.Debug("chunk tokens", "index", i, "numTokens", numTokens) // Encode the chunk if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { return nil, errors.New("unable to encode mtmd image chunk") } - } - // Get the embeddings - embed := make([][]float32, lastChunkSize) - embd := C.mtmd_get_output_embd(c.c) - if nil == embd { - return nil, errors.New("failed to get image embedding") - } + // Get the embeddings for this chunk + chunkEmbed := make([][]float32, numTokens) + chunkEmbd := C.mtmd_get_output_embd(c.c) + if nil == chunkEmbd { + continue + } - // Extend the embedding array for each token - s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize) - rows := make([]float32, len(s)) - copy(rows, s) - for i := range lastChunkSize { - embed[i] = rows[i*numEmbed : (i+1)*numEmbed] + // Extend the embedding array for each token + s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed) + rows := make([]float32, len(s)) + copy(rows, s) + for i := range numTokens { + chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed] + } + embed = append(embed, chunkEmbed...) } - + slog.Debug("image embeddings", "totalEmbeddings", len(embed)) return embed, nil } From 9d56e63dbf369599007876f207570fce21683030 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Fri, 12 Sep 2025 13:32:02 -0700 Subject: [PATCH 15/32] Revert "runner: simplify parser entrypoints in runner (#12233)" This reverts commit 8d6fffaead722d86622b17c059051796d2621858. --- harmony/harmonyparser.go | 18 ++--- harmony/harmonyparser_test.go | 28 ++++---- llm/server.go | 7 +- parser/token_parser.go | 126 ---------------------------------- runner/ollamarunner/runner.go | 54 ++++++++++++--- server/routes.go | 17 +---- 6 files changed, 77 insertions(+), 173 deletions(-) delete mode 100644 parser/token_parser.go diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index 3ec2c21f1..addce4c94 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -289,7 +289,6 @@ type HarmonyMessageHandler struct { state harmonyMessageState HarmonyParser *HarmonyParser FunctionNameMap *FunctionNameMap - ToolParser *HarmonyToolCallAccumulator } // NewHarmonyMessageHandler creates a new message handler @@ -302,16 +301,12 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler { HeaderEndTag: "<|message|>", }, FunctionNameMap: NewFunctionNameMap(), - ToolParser: &HarmonyToolCallAccumulator{ - state: harmonyToolCallState_Normal, - currentToolName: nil, - }, } } // AddContent processes the content and returns the content, thinking, and tool content. 
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser -func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) { +func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) { contentSb := strings.Builder{} thinkingSb := strings.Builder{} toolContentSb := strings.Builder{} @@ -328,14 +323,14 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri // event.Header.Recipient is the tool name, something like // "browser.search" for a built-in, or "functions.calc" for a // custom one - h.ToolParser.SetToolName(event.Header.Recipient) + toolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Thinking } case "commentary": if event.Header.Recipient != "" { h.state = harmonyMessageState_ToolCalling - h.ToolParser.SetToolName(event.Header.Recipient) + toolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Normal } @@ -358,6 +353,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri return contentSb.String(), thinkingSb.String(), toolContentSb.String() } +func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator { + return &HarmonyToolCallAccumulator{ + state: harmonyToolCallState_Normal, + currentToolName: nil, + } +} + type harmonyToolCallState int const ( diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go index 82bf5b2de..dcf1af4e8 100644 --- a/harmony/harmonyparser_test.go +++ b/harmony/harmonyparser_test.go @@ -541,7 +541,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("thinking_then_content_streams", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() type step struct { in string wantContent string @@ -554,7 +554,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { {in: "<|end|>", wantContent: ""}, } for i, s := range steps { - content, thinking, tool := handler.AddContent(s.in) + content, thinking, tool := handler.AddContent(s.in, tp) if tool != "" { tp.Add(tool) } @@ -567,7 +567,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("content_streams_as_it_arrives", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|start|>assistant<|message|>Hello", ", world", @@ -575,7 +575,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var got []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if tool != "" { tp.Add(tool) } @@ -595,7 +595,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("thinking_streams_separately_from_content", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|channel|>analysis<|message|>Thinking...", "<|end|>", @@ -604,7 +604,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var got []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if tool != "" { tp.Add(tool) } @@ -624,7 +624,7 @@ 
func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("partial_tags_buffer_until_complete", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|chan", "nel|>analysis<|mess", @@ -637,7 +637,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { var thinkingPieces []string var contentPieces []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if tool != "" { tp.Add(tool) } @@ -659,7 +659,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("simple_assistant_after_analysis", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|channel|>analysis<|message|>Think", "<|end|>", @@ -668,7 +668,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var contentSb, thinkingSb strings.Builder for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if tool != "" { tp.Add(tool) } @@ -686,12 +686,12 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>", } for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if content != "" || thinking != "" { continue } @@ -711,14 +711,14 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("tool_call_across_chunks", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+", "2\"}", "<|end|>", } for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if content != "" || thinking != "" { continue } diff --git a/llm/server.go b/llm/server.go index 9100b6978..45a9ad14c 100644 --- a/llm/server.go +++ b/llm/server.go @@ -35,7 +35,6 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" - "github.com/ollama/ollama/parser" ) type filteredEnv []string @@ -1350,7 +1349,7 @@ type CompletionRequest struct { Options *api.Options Grammar string // set before sending the request to the subprocess - ParserType parser.TokenParserType + UseHarmony bool PrefillString string } @@ -1364,6 +1363,8 @@ const ( DoneReasonLength // DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed DoneReasonConnectionClosed + // DoneReasonTokenRepeatLimit indicates the completion stopped due to a token repeat limit + DoneReasonTokenRepeatLimit ) func (d DoneReason) String() string { @@ -1372,6 +1373,8 @@ func (d DoneReason) String() string { return "length" case DoneReasonStop: return "stop" + case DoneReasonTokenRepeatLimit: + return "token_repeat_limit" default: return "" // closed } diff --git 
a/parser/token_parser.go b/parser/token_parser.go deleted file mode 100644 index 812458299..000000000 --- a/parser/token_parser.go +++ /dev/null @@ -1,126 +0,0 @@ -package parser - -import ( - "encoding/json" - "errors" - "strings" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/harmony" -) - -type TokenParserType int - -const ( - TokenParserTypeDefault TokenParserType = iota - TokenParserTypeHarmony -) - -type TokenParser struct { - messageHandler MessageHandler - parserEngine ParserInternals - toolParser ToolParser - lastToken string - tokenRepeat int - repeatLimit int -} - -const defaultTokenRepeatLimit = 30 - -type MessageHandler interface { - AddContent(token string) (content, thinking string, toolContent string) -} - -type ParserInternals interface { - AddImplicitStartOrPrefill(prefillString string) -} - -type ToolParser interface { - Add(token string) - Drain() (toolName *string, toolContent string) -} - -// Default implementation for the TokenParser interface as a no-op passthrough -type defaultMessageHandler struct{} - -func (defaultMessageHandler) AddContent(token string) (string, string, string) { - return token, "", "" -} - -type defaultEngine struct{} - -func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {} - -type defaultToolParser struct{} - -func (defaultToolParser) Add(token string) {} - -func (defaultToolParser) Drain() (*string, string) { return nil, "" } - -func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser { - switch parserType { - case TokenParserTypeHarmony: - harmonyMessageHandler := harmony.NewHarmonyMessageHandler() - harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString) - return TokenParser{ - messageHandler: harmonyMessageHandler, - parserEngine: harmonyMessageHandler.HarmonyParser, - toolParser: harmonyMessageHandler.ToolParser, - repeatLimit: defaultTokenRepeatLimit, - } - - default: - return TokenParser{ - messageHandler: defaultMessageHandler{}, - parserEngine: defaultEngine{}, - toolParser: defaultToolParser{}, - repeatLimit: 30, - } - } -} - -func (p *TokenParser) AddContent(token string) (string, string, error) { - if p.repeatLimitReached(token) { - return "", "", errors.New("token repeat limit reached") - } - content, thinking, toolContent := p.messageHandler.AddContent(token) - p.toolParser.Add(toolContent) - return content, thinking, nil -} - -// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached. 
-func (p *TokenParser) repeatLimitReached(token string) bool { - if p == nil { - return false - } - trimmed := strings.TrimSpace(token) - if trimmed == p.lastToken { - p.tokenRepeat++ - } else { - p.tokenRepeat = 0 - } - p.lastToken = trimmed - - return p.tokenRepeat >= p.repeatLimit -} - -// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level -func (p *TokenParser) Drain() []api.ToolCall { - toolName, toolContent := p.toolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - return nil - } - return []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: *toolName, - Arguments: args, - }, - }, - } - } - return nil -} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 676e5186f..5da8ca3cb 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -29,12 +29,12 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/harmony" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" - "github.com/ollama/ollama/parser" "github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/sample" @@ -781,7 +781,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString) + var harmonyMessageHandler *harmony.HarmonyMessageHandler + var harmonyToolParser *harmony.HarmonyToolCallAccumulator + if req.UseHarmony { + harmonyMessageHandler = harmony.NewHarmonyMessageHandler() + harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillString) + harmonyToolParser = harmonyMessageHandler.CreateToolParser() + } if req.Options == nil { opts := api.DefaultOptions() @@ -865,6 +871,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { http.Error(w, "could not find an available sequence", http.StatusInternalServerError) return } + var lastToken string + tokenRepeat := 0 + const tokenRepeatLimit = 30 for { select { @@ -873,14 +882,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - var thinking string - var err error - content, thinking, err = tokenParser.AddContent(content) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + if strings.TrimSpace(content) == lastToken { + tokenRepeat++ + } + if tokenRepeat == tokenRepeatLimit { + http.Error(w, "token repeat limit reached", http.StatusInternalServerError) + seq.doneReason = llm.DoneReasonTokenRepeatLimit close(seq.quit) return } + lastToken = strings.TrimSpace(content) + + var thinking string + if harmonyMessageHandler != nil { + var toolContent string + content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser) + harmonyToolParser.Add(toolContent) + } if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, @@ -893,7 +911,27 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { - toolCalls := tokenParser.Drain() + var toolCalls []api.ToolCall + if harmonyMessageHandler != nil { + // these tools still need to be transformed to the original function name + toolName, toolContent := harmonyToolParser.Drain() + if 
toolName != nil { + *toolName = strings.TrimPrefix(*toolName, "functions.") + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(toolContent), &args); err != nil { + http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError) + close(seq.quit) + return + } + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: *toolName, + Arguments: args, + }, + }) + } + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ ToolCalls: toolCalls, Done: true, diff --git a/server/routes.go b/server/routes.go index 8dd1b217a..da5e22f68 100644 --- a/server/routes.go +++ b/server/routes.go @@ -36,7 +36,6 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/openai" - "github.com/ollama/ollama/parser" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" @@ -197,12 +196,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { } useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw - var parserType parser.TokenParserType - if useHarmony { - parserType = parser.TokenParserTypeHarmony - } else { - parserType = parser.TokenParserTypeDefault - } var functionNameMap *harmony.FunctionNameMap if useHarmony { @@ -354,7 +347,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { Images: images, Format: req.Format, Options: opts, - ParserType: parserType, + UseHarmony: useHarmony, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ Model: req.Model, @@ -1600,12 +1593,6 @@ func (s *Server) ChatHandler(c *gin.Context) { msgs = filterThinkTags(msgs, m) useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) - var parserType parser.TokenParserType - if useHarmony { - parserType = parser.TokenParserTypeHarmony - } else { - parserType = parser.TokenParserTypeDefault - } processedTools := req.Tools var functionNameMap *harmony.FunctionNameMap @@ -1676,7 +1663,7 @@ func (s *Server) ChatHandler(c *gin.Context) { Images: images, Format: req.Format, Options: opts, - ParserType: parserType, + UseHarmony: useHarmony, PrefillString: prefillString, }, func(r llm.CompletionResponse) { res := api.ChatResponse{ From 92b96d54efd6b49322b7cf046f9a0dc16b00cd0a Mon Sep 17 00:00:00 2001 From: jmorganca Date: Fri, 12 Sep 2025 13:32:30 -0700 Subject: [PATCH 16/32] Revert "runner: move harmony to runner (#12052)" This reverts commit 1a558f98e2d07885efb6cf82943ae029c647f3d0. 
--- harmony/harmonyparser.go | 46 ++--- harmony/harmonyparser_test.go | 200 -------------------- llm/server.go | 37 ++-- runner/ollamarunner/runner.go | 55 +----- server/routes.go | 131 ++++++++----- server/routes_harmony_streaming_test.go | 237 ++++++++++++++++++++++-- 6 files changed, 337 insertions(+), 369 deletions(-) diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index addce4c94..a51819dda 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -3,29 +3,15 @@ package harmony import ( "fmt" "log/slog" - "slices" "strings" "unicode" "github.com/ollama/ollama/api" "github.com/ollama/ollama/logutil" - "github.com/ollama/ollama/template" ) type harmonyParserState int -func ShouldUseHarmony(modelFamily string, template *template.Template) bool { - if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) { - // heuristic to check whether the template expects to be parsed via harmony: - // search for harmony tags that are nearly always used - if template.Contains("<|start|>") && template.Contains("<|end|>") { - return true - } - } - - return false -} - const ( harmonyParserState_LookingForMessageStart harmonyParserState = iota harmonyParserState_ParsingHeader @@ -89,28 +75,18 @@ func (s *HarmonyParser) AddImplicitStart() { s.acc.WriteString("<|start|>assistant") } -func Prefill(lastMessage api.Message) string { - if lastMessage.Role != "assistant" { - return "" - } - - switch { - case strings.TrimSpace(lastMessage.Content) != "": - return "<|start|>assistant<|channel|>final<|message|>" - case strings.TrimSpace(lastMessage.Thinking) != "": - return "<|start|>assistant<|channel|>analysis<|message|>" - default: - return "" - } -} - -// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided -func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) { - if strings.TrimSpace(prefillString) != "" { - s.acc.WriteString(prefillString) - } else { - s.AddImplicitStart() +func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) { + if lastMessage != nil && lastMessage.Role == "assistant" { + // handle prefilling conditions + if lastMessage.Content != "" { + s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>") + return + } else if lastMessage.Thinking != "" { + s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>") + return + } } + s.AddImplicitStart() } func (s *HarmonyParser) AddContent(content string) []HarmonyEvent { diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go index dcf1af4e8..b988a018f 100644 --- a/harmony/harmonyparser_test.go +++ b/harmony/harmonyparser_test.go @@ -3,7 +3,6 @@ package harmony import ( "fmt" "reflect" - "strings" "testing" ) @@ -536,202 +535,3 @@ func TestFunctionConvertAndAdd(t *testing.T) { }) } } - -func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { - t.Run("thinking_then_content_streams", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() - type step struct { - in string - wantContent string - wantThinking string - } - steps := []step{ - {in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."}, - {in: "<|end|>", wantThinking: ""}, - {in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"}, - {in: "<|end|>", wantContent: ""}, - } - for i, s := range steps { - content, thinking, tool := handler.AddContent(s.in, tp) - if tool != "" { - tp.Add(tool) - } - if content != s.wantContent || 
thinking != s.wantThinking { - t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking) - } - } - }) - - t.Run("content_streams_as_it_arrives", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() - inputs := []string{ - "<|start|>assistant<|message|>Hello", - ", world", - "!<|end|>", - } - var got []string - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) - if tool != "" { - tp.Add(tool) - } - if thinking != "" { - t.Fatalf("unexpected thinking %q", thinking) - } - if content != "" { - got = append(got, content) - } - } - want := []string{"Hello", ", world", "!"} - if !reflect.DeepEqual(got, want) { - t.Fatalf("content pieces mismatch: got %v want %v", got, want) - } - }) - - t.Run("thinking_streams_separately_from_content", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() - inputs := []string{ - "<|channel|>analysis<|message|>Thinking...", - "<|end|>", - "<|start|>assistant<|message|>Answer", - "<|end|>", - } - var got []string - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) - if tool != "" { - tp.Add(tool) - } - if thinking != "" { - got = append(got, thinking) - } - if content != "" { - got = append(got, content) - } - } - want := []string{"Thinking...", "Answer"} - if !reflect.DeepEqual(got, want) { - t.Fatalf("content pieces mismatch: got %v want %v", got, want) - } - }) - - t.Run("partial_tags_buffer_until_complete", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() - inputs := []string{ - "<|chan", - "nel|>analysis<|mess", - "age|>Deep ", - "thought", - "<|end|>", - "<|start|>assistant<|message|>Done", - "<|end|>", - } - var thinkingPieces []string - var contentPieces []string - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) - if tool != "" { - tp.Add(tool) - } - if thinking != "" { - thinkingPieces = append(thinkingPieces, thinking) - } - if content != "" { - contentPieces = append(contentPieces, content) - } - } - if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) { - t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want) - } - if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) { - t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want) - } - }) - - t.Run("simple_assistant_after_analysis", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() - inputs := []string{ - "<|channel|>analysis<|message|>Think", - "<|end|>", - "<|start|>assistant<|message|>Answer", - "<|end|>", - } - var contentSb, thinkingSb strings.Builder - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) - if tool != "" { - tp.Add(tool) - } - contentSb.WriteString(content) - thinkingSb.WriteString(thinking) - } - if contentSb.String() != "Answer" { - t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer") - } - if thinkingSb.String() != "Think" { - t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think") - } - }) - - t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - 
handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() - inputs := []string{ - "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>", - } - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) - if content != "" || thinking != "" { - continue - } - if tool != "" { - tp.Add(tool) - } - } - name, args := tp.Drain() - if name == nil || *name != "functions.calculate" { - t.Fatalf("unexpected tool name: %v", name) - } - if got, want := args, "{\"expression\":\"2+2\"}"; got != want { - t.Fatalf("unexpected tool args: got %s want %s", got, want) - } - }) - - t.Run("tool_call_across_chunks", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() - inputs := []string{ - "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+", - "2\"}", - "<|end|>", - } - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) - if content != "" || thinking != "" { - continue - } - if tool != "" { - tp.Add(tool) - } - } - name, args := tp.Drain() - if name == nil || *name != "functions.calculate" { - t.Fatalf("unexpected tool name: %v", name) - } - if got, want := args, "{\"expression\":\"2+2\"}"; got != want { - t.Fatalf("unexpected tool args: got %s want %s", got, want) - } - }) -} diff --git a/llm/server.go b/llm/server.go index 45a9ad14c..75f049bc0 100644 --- a/llm/server.go +++ b/llm/server.go @@ -1348,9 +1348,7 @@ type CompletionRequest struct { Images []ImageData Options *api.Options - Grammar string // set before sending the request to the subprocess - UseHarmony bool - PrefillString string + Grammar string // set before sending the request to the subprocess } // DoneReason represents the reason why a completion response is done @@ -1363,8 +1361,6 @@ const ( DoneReasonLength // DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed DoneReasonConnectionClosed - // DoneReasonTokenRepeatLimit indicates the completion stopped due to a token repeat limit - DoneReasonTokenRepeatLimit ) func (d DoneReason) String() string { @@ -1373,23 +1369,19 @@ func (d DoneReason) String() string { return "length" case DoneReasonStop: return "stop" - case DoneReasonTokenRepeatLimit: - return "token_repeat_limit" default: return "" // closed } } type CompletionResponse struct { - Content string `json:"content"` - Thinking string `json:"thinking"` - ToolCalls []api.ToolCall `json:"tool_calls"` - DoneReason DoneReason `json:"done_reason"` - Done bool `json:"done"` - PromptEvalCount int `json:"prompt_eval_count"` - PromptEvalDuration time.Duration `json:"prompt_eval_duration"` - EvalCount int `json:"eval_count"` - EvalDuration time.Duration `json:"eval_duration"` + Content string `json:"content"` + DoneReason DoneReason `json:"done_reason"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration time.Duration `json:"eval_duration"` } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { @@ -1507,8 +1499,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return fmt.Errorf("error unmarshalling llm prediction response: %v", err) } switch { - // TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can 
be removed in the future - case strings.TrimSpace(c.Content) == lastToken && c.Content != "": + case strings.TrimSpace(c.Content) == lastToken: tokenRepeat++ default: lastToken = strings.TrimSpace(c.Content) @@ -1521,14 +1512,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return ctx.Err() } + if c.Content != "" { + fn(CompletionResponse{ + Content: c.Content, + }) + } + if c.Done { fn(c) return nil } - - if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 { - fn(c) - } } } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 5da8ca3cb..1081a1f55 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -29,7 +29,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/harmony" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" @@ -781,14 +780,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator - if req.UseHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() - harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillString) - harmonyToolParser = harmonyMessageHandler.CreateToolParser() - } - if req.Options == nil { opts := api.DefaultOptions() req.Options = &opts @@ -871,9 +862,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { http.Error(w, "could not find an available sequence", http.StatusInternalServerError) return } - var lastToken string - tokenRepeat := 0 - const tokenRepeatLimit = 30 for { select { @@ -882,27 +870,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if strings.TrimSpace(content) == lastToken { - tokenRepeat++ - } - if tokenRepeat == tokenRepeatLimit { - http.Error(w, "token repeat limit reached", http.StatusInternalServerError) - seq.doneReason = llm.DoneReasonTokenRepeatLimit - close(seq.quit) - return - } - lastToken = strings.TrimSpace(content) - - var thinking string - if harmonyMessageHandler != nil { - var toolContent string - content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser) - harmonyToolParser.Add(toolContent) - } - if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ - Content: content, - Thinking: thinking, + Content: content, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) close(seq.quit) @@ -911,29 +880,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { - var toolCalls []api.ToolCall - if harmonyMessageHandler != nil { - // these tools still need to be transformed to the original function name - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError) - close(seq.quit) - return - } - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: *toolName, - Arguments: args, - }, - }) - } - } - if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ - ToolCalls: toolCalls, Done: true, 
DoneReason: seq.doneReason, PromptEvalCount: seq.numPromptInputs, diff --git a/server/routes.go b/server/routes.go index da5e22f68..5114cb74f 100644 --- a/server/routes.go +++ b/server/routes.go @@ -46,6 +46,18 @@ import ( "github.com/ollama/ollama/version" ) +func shouldUseHarmony(model *Model) bool { + if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { + // heuristic to check whether the template expects to be parsed via harmony: + // search for harmony tags that are nearly always used + if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") { + return true + } + } + + return false +} + func experimentEnabled(name string) bool { return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) } @@ -195,11 +207,13 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw - var functionNameMap *harmony.FunctionNameMap - + useHarmony := shouldUseHarmony(m) && !req.Raw + var harmonyMessageHandler *harmony.HarmonyMessageHandler + var harmonyToolParser *harmony.HarmonyToolCallAccumulator if useHarmony { - functionNameMap = harmony.NewFunctionNameMap() + harmonyMessageHandler = harmony.NewHarmonyMessageHandler() + harmonyMessageHandler.HarmonyParser.AddImplicitStart() + harmonyToolParser = harmonyMessageHandler.CreateToolParser() } // Validate Think value: string values currently only allowed for gptoss models @@ -343,19 +357,16 @@ func (s *Server) GenerateHandler(c *gin.Context) { var sb strings.Builder defer close(ch) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: req.Format, - Options: opts, - UseHarmony: useHarmony, + Prompt: prompt, + Images: images, + Format: req.Format, + Options: opts, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), Response: cr.Content, Done: cr.Done, - Thinking: cr.Thinking, - ToolCalls: cr.ToolCalls, Metrics: api.Metrics{ PromptEvalCount: cr.PromptEvalCount, PromptEvalDuration: cr.PromptEvalDuration, @@ -364,22 +375,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { }, } - if res.Done { - res.DoneReason = cr.DoneReason.String() - res.TotalDuration = time.Since(checkpointStart) - res.LoadDuration = checkpointLoaded.Sub(checkpointStart) - } - if useHarmony { - for i, tool := range res.ToolCalls { - res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name) - } - if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done { - ch <- res - } - return - } - if thinkingState != nil { + content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser) + res.Response = content + res.Thinking = thinking + harmonyToolParser.Add(toolContent) + } else if thinkingState != nil { thinking, content := thinkingState.AddContent(cr.Content) res.Thinking = thinking res.Response = content @@ -390,6 +391,30 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if cr.Done { + if useHarmony { + toolName, toolContent := harmonyToolParser.Drain() + if toolName != nil { + *toolName = strings.TrimPrefix(*toolName, "functions.") + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(toolContent), &args); err != nil { + errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) + ch <- gin.H{"error": errStr} + return + } + + res.ToolCalls = append(res.ToolCalls, 
api.ToolCall{ + Function: api.ToolCallFunction{ + Name: *toolName, + Arguments: args, + }, + }) + } + } + + res.DoneReason = cr.DoneReason.String() + res.TotalDuration = time.Since(checkpointStart) + res.LoadDuration = checkpointLoaded.Sub(checkpointStart) + if !req.Raw { tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String()) if err != nil { @@ -1592,21 +1617,27 @@ func (s *Server) ChatHandler(c *gin.Context) { } msgs = filterThinkTags(msgs, m) - useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) + var harmonyMessageHandler *harmony.HarmonyMessageHandler + var harmonyToolParser *harmony.HarmonyToolCallAccumulator + + useHarmony := shouldUseHarmony(m) processedTools := req.Tools - var functionNameMap *harmony.FunctionNameMap - var prefillString string - // TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner if useHarmony { - prefillString = harmony.Prefill(msgs[len(msgs)-1]) - functionNameMap = harmony.NewFunctionNameMap() + harmonyMessageHandler = harmony.NewHarmonyMessageHandler() + var lastMessage *api.Message + if len(msgs) > 0 { + lastMessage = &msgs[len(msgs)-1] + } + harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage) + harmonyToolParser = harmonyMessageHandler.CreateToolParser() + // make a copy of tools to pass to the chat prompt. Function names may be // renamed to be valid Harmony function names. processedTools = make([]api.Tool, len(req.Tools)) copy(processedTools, req.Tools) for i, tool := range processedTools { - processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name) + processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name) } } @@ -1659,17 +1690,15 @@ func (s *Server) ChatHandler(c *gin.Context) { defer close(ch) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: req.Format, - Options: opts, - UseHarmony: useHarmony, - PrefillString: prefillString, + Prompt: prompt, + Images: images, + Format: req.Format, + Options: opts, }, func(r llm.CompletionResponse) { res := api.ChatResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), - Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls}, + Message: api.Message{Role: "assistant", Content: r.Content}, Done: r.Done, Metrics: api.Metrics{ PromptEvalCount: r.PromptEvalCount, @@ -1685,13 +1714,31 @@ func (s *Server) ChatHandler(c *gin.Context) { } if useHarmony { - for i, tool := range res.Message.ToolCalls { - res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name) + content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser) + res.Message.Content = content + res.Message.Thinking = thinking + harmonyToolParser.Add(toolContent) + + if r.Done { + toolName, toolContent := harmonyToolParser.Drain() + if toolName != nil { + *toolName = strings.TrimPrefix(*toolName, "functions.") + *toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName) + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(toolContent), &args); err != nil { + errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) + ch <- gin.H{"error": errStr} + return + } + res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}} + } } + // only send messages with meaningful content 
(empty messages confuse clients) if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done { ch <- res } + return } diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go index bcb020886..b1ede4e39 100644 --- a/server/routes_harmony_streaming_test.go +++ b/server/routes_harmony_streaming_test.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "encoding/json" + "net/http" "strings" "testing" "time" @@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "content streams as it arrives", steps: []step{ { - input: llm.CompletionResponse{Content: "Hello", Done: false}, + input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false}, wantContent: "Hello", }, { @@ -125,7 +126,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { wantContent: ", world", }, { - input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop}, + input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, wantContent: "!", }, }, @@ -134,15 +135,20 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "thinking streams separately from content", steps: []step{ { - input: llm.CompletionResponse{Thinking: "Thinking...", Done: false}, + input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false}, wantThinking: "Thinking...", }, { - input: llm.CompletionResponse{Content: "Answer", Done: false}, - wantContent: "Answer", + input: llm.CompletionResponse{Content: "<|end|>", Done: false}, + // No output expected - just closes the analysis message and resets state to normal }, { - input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop}, + input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false}, + wantContent: "Answer", // After message end, state is reset to normal + }, + { + input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + // No output expected - just closes the assistant message }, }, }, @@ -150,16 +156,24 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "partial tags buffer until complete", steps: []step{ { - input: llm.CompletionResponse{Thinking: "Deep ", Done: false}, + input: llm.CompletionResponse{Content: "<|chan", Done: false}, + // No output - partial tag + }, + { + input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false}, + // No output - still building tags + }, + { + input: llm.CompletionResponse{Content: "age|>Deep ", Done: false}, wantThinking: "Deep ", }, { - input: llm.CompletionResponse{Thinking: "thought", Done: false}, + input: llm.CompletionResponse{Content: "thought<|end|>", Done: false}, wantThinking: "thought", }, { - input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop}, - wantContent: "Done", + input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + wantContent: "Done", // After message end, state is reset to normal }, }, }, @@ -167,7 +181,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "simple assistant after analysis", steps: []step{ { - input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop}, + input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, 
DoneReason: llm.DoneReasonStop}, wantContent: "Answer", wantThinking: "Think", }, @@ -177,7 +191,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "tool call parsed and returned correctly", steps: []step{ { - input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop}, + input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, wantContent: "The weather is sunny", wantToolCalls: []api.ToolCall{ { @@ -196,10 +210,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "tool call with streaming JSON across chunks", steps: []step{ { - input: llm.CompletionResponse{Done: false}, + input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false}, + // No output yet - incomplete JSON }, { - input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true}, + input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false}, + // Still no output - incomplete JSON + }, + { + input: llm.CompletionResponse{Content: "2\"}", Done: true}, wantToolCalls: []api.ToolCall{ { Function: api.ToolCallFunction{ @@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) { gin.SetMode(gin.TestMode) mockResponses := []llm.CompletionResponse{ - {Content: "First ", Done: false}, + {Content: "<|message|>First ", Done: false}, {Content: "chunk ", Done: false}, - {Content: "here", Done: true, DoneReason: llm.DoneReasonStop}, + {Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, } mock := mockRunner{ @@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) { t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks) } } + +func TestChatHarmonyParserStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + type expectedChunk struct { + afterResponse int // Which mock response this chunk should appear after + content string // Expected content in this chunk + thinking string // Expected thinking in this chunk + } + + testCases := []struct { + name string + mockResponses []llm.CompletionResponse + expectedChunks []expectedChunk + wantContent string + wantThinking string + }{ + { + name: "simple message without thinking", + mockResponses: []llm.CompletionResponse{ + {Content: "<|start|>assistant<|message|>Hello, ", Done: false}, + {Content: "how can I help?", Done: false}, + {Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + }, + expectedChunks: []expectedChunk{ + {afterResponse: 1, content: "Hello, "}, + {afterResponse: 2, content: "how can I help?"}, + }, + wantContent: "Hello, how can I help?", + }, + { + name: "message with analysis channel for thinking", + mockResponses: []llm.CompletionResponse{ + {Content: "<|channel|>analysis<|message|>", Done: false}, + {Content: "Let me think ", Done: false}, + {Content: "about this problem...", Done: false}, + {Content: "<|end|>", Done: false}, + {Content: "<|start|>assistant<|message|>", Done: false}, + {Content: "The answer ", Done: false}, + {Content: "is 42", Done: false}, + {Content: "<|end|>", 
Done: true, DoneReason: llm.DoneReasonStop}, + }, + expectedChunks: []expectedChunk{ + {afterResponse: 2, thinking: "Let me think "}, + {afterResponse: 3, thinking: "about this problem..."}, + {afterResponse: 6, content: "The answer "}, + {afterResponse: 7, content: "is 42"}, + }, + wantContent: "The answer is 42", + wantThinking: "Let me think about this problem...", + }, + { + name: "streaming with partial tags across boundaries", + mockResponses: []llm.CompletionResponse{ + {Content: "<|chan", Done: false}, + {Content: "nel|>analy", Done: false}, + {Content: "sis<|mess", Done: false}, + {Content: "age|>Think", Done: false}, + {Content: "ing deeply...<|end|>", Done: false}, + {Content: "<|start|>assi", Done: false}, + {Content: "stant<|message|>Result ", Done: false}, + {Content: "computed<|e", Done: false}, + {Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop}, + }, + expectedChunks: []expectedChunk{ + {afterResponse: 4, thinking: "Think"}, + {afterResponse: 5, thinking: "ing deeply..."}, + {afterResponse: 7, content: "Result "}, + {afterResponse: 8, content: "computed"}, + }, + wantContent: "Result computed", + wantThinking: "Thinking deeply...", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Channel to synchronize mock responses with chunk verification + responsesSent := make(chan int, len(tc.mockResponses)) + + mock := mockRunner{ + CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error { + // Send mock responses one at a time, notifying when each is sent + for i, resp := range tc.mockResponses { + fn(resp) + responsesSent <- i + 1 + } + close(responsesSent) + return nil + }, + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: discover.GetGPUInfo, + getCpuFn: discover.GetCPUInfo, + reschedDelay: 250 * time.Millisecond, + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { + req.successCh <- &runnerRef{ + llama: &mock, + } + return false + }, + }, + } + + go s.sched.Run(t.Context()) + + // Create a minimal model + _, digest := createHarmonyTestModel(t) + + // Create model with passthrough template + stream := false + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "harmony-test", + Files: map[string]string{"file.gguf": digest}, + Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("failed to create model: %d", w.Code) + } + + // Test chat endpoint with streaming + streamTrue := true + w = createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "harmony-test", + Messages: []api.Message{{Role: "user", Content: "Hello"}}, + Stream: &streamTrue, + Tools: getTestTools(), + }) + + if w.Code != http.StatusOK { + t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String()) + } + + // Parse streaming response + var chunks []api.ChatResponse + var content, thinking strings.Builder + + decoder := json.NewDecoder(w.Body) + for decoder.More() { + var chunk api.ChatResponse + if err := decoder.Decode(&chunk); err != nil { + t.Fatalf("failed to decode chunk: %v", err) + } + chunks = append(chunks, chunk) + + // Accumulate content and thinking from each chunk + content.WriteString(chunk.Message.Content) + 
thinking.WriteString(chunk.Message.Thinking) + + // Debug output + t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done) + } + + // Verify we got streaming chunks + if len(chunks) == 0 { + t.Fatal("expected streaming chunks, got none") + } + + gotContent := content.String() + gotThinking := thinking.String() + + if gotContent != tc.wantContent { + t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent) + } + if gotThinking != tc.wantThinking { + t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking) + } + + // Verify last chunk has done=true + lastChunk := chunks[len(chunks)-1] + if !lastChunk.Done { + t.Error("expected last chunk to have done=true") + } + }) + } +} From 47991940d44d8c3db3a7f0d36135976de7aadf81 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Thu, 11 Sep 2025 13:40:35 -0700 Subject: [PATCH 17/32] add qwen3-coder tool support The format qwen3-coder uses is relatively unique, both in rendering and in parsing. To implement parsing, I wrote a custom parser in similar style to harmony. For the rendering, I found that the logic would be much more difficult to follow in a template, so I introduced the concept of a built-in renderer that uses go code, rather than a template to generate prompts. I set us up for future built-in parsers and renderers by making it so they can be specified in a Modelfile like so: ``` RENDERER "qwen3-coder" PARSER "qwen3-coder" ``` These need to be provided explicitly because the architecture alone is not enough to understand what format the model expects to receive, and what format we expect it to output (e.g., qwen3-coder is `qwen3moe`, which includes other qwen3-family models as well) I haven't converted harmony to be one of these "built-ins" yet, since some of it is in flux with the changes @ParthSareen has been making to move harmony to the runner. It is likely that many other built-ins will need to move to the runner as well, but I'm able to slightly defer that decision since qwen3-coder doesn't have thinking (and therefore doesn't need to be in the runner to make structured outputs work). I expect to unify harmony with this approach very soon. Whether a particular model supports tools or thinking was previously inferred from templates, but without a template we now also use the parser itself to declare what it supports. If we have future models that re-use the same parsing format, but have different capabilities, we'll want to parameterize them and give them different names to be specified as a `PARSER`. Misc changes: - I worked on the renderer by diffing outputs from the reference implementation and ours. 
To make it easier to do this, I extended to also support returning the prompt via the openai compat layer --- api/types.go | 22 +- model/parsers/parsers.go | 37 ++ model/parsers/qwen3coder.go | 410 ++++++++++++++ model/parsers/qwen3coder_test.go | 830 +++++++++++++++++++++++++++++ model/renderers/qwen3coder.go | 217 ++++++++ model/renderers/qwen3coder_test.go | 338 ++++++++++++ model/renderers/renderer.go | 26 + openai/openai.go | 47 +- parser/parser.go | 10 +- parser/parser_test.go | 28 + server/create.go | 2 + server/images.go | 23 +- server/prompt.go | 37 +- server/routes.go | 38 +- server/routes_debug_test.go | 4 +- 15 files changed, 2012 insertions(+), 57 deletions(-) create mode 100644 model/parsers/parsers.go create mode 100644 model/parsers/qwen3coder.go create mode 100644 model/parsers/qwen3coder_test.go create mode 100644 model/renderers/qwen3coder.go create mode 100644 model/renderers/qwen3coder_test.go create mode 100644 model/renderers/renderer.go diff --git a/api/types.go b/api/types.go index a7ddbc373..df3504c3b 100644 --- a/api/types.go +++ b/api/types.go @@ -313,10 +313,11 @@ func (t *ToolFunction) String() string { // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. type ChatResponse struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - Message Message `json:"message"` - DoneReason string `json:"done_reason,omitempty"` + Model string `json:"model"` + CreatedAt time.Time `json:"created_at"` + Message Message `json:"message"` + DoneReason string `json:"done_reason,omitempty"` + DebugInfo *DebugInfo `json:"_debug_info,omitempty"` Done bool `json:"done"` @@ -329,13 +330,6 @@ type DebugInfo struct { ImageCount int `json:"image_count,omitempty"` } -// DebugTemplateResponse is returned when _debug_render_only is set to true -type DebugTemplateResponse struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - DebugInfo DebugInfo `json:"_debug_info"` -} - type Metrics struct { TotalDuration time.Duration `json:"total_duration,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"` @@ -443,6 +437,8 @@ type CreateRequest struct { System string `json:"system,omitempty"` Parameters map[string]any `json:"parameters,omitempty"` Messages []Message `json:"messages,omitempty"` + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` // Deprecated: set the model name with Model instead Name string `json:"name"` @@ -480,6 +476,8 @@ type ShowResponse struct { Parameters string `json:"parameters,omitempty"` Template string `json:"template,omitempty"` System string `json:"system,omitempty"` + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` Details ModelDetails `json:"details,omitempty"` Messages []Message `json:"messages,omitempty"` ModelInfo map[string]any `json:"model_info,omitempty"` @@ -592,6 +590,8 @@ type GenerateResponse struct { Metrics ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + DebugInfo *DebugInfo `json:"_debug_info,omitempty"` } // ModelDetails provides details about a model. 
diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go new file mode 100644 index 000000000..001cac442 --- /dev/null +++ b/model/parsers/parsers.go @@ -0,0 +1,37 @@ +package parsers + +import ( + "github.com/ollama/ollama/api" +) + +type BuiltinParser interface { + Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) + HasToolSupport() bool + HasThinkingSupport() bool +} + +func ParserForName(name string) BuiltinParser { + switch name { + case "qwen3-coder": + parser := &Qwen3CoderParser{} + return parser + case "passthrough": + return &PassthroughParser{} + default: + return nil + } +} + +type PassthroughParser struct{} + +func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) { + return s, "", nil, nil +} + +func (p *PassthroughParser) HasToolSupport() bool { + return false +} + +func (p *PassthroughParser) HasThinkingSupport() bool { + return false +} diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go new file mode 100644 index 000000000..b0e8ec48c --- /dev/null +++ b/model/parsers/qwen3coder.go @@ -0,0 +1,410 @@ +package parsers + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "log/slog" + "math" + "regexp" + "strconv" + "strings" + "unicode" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/logutil" +) + +type qwenParserState int + +const ( + toolOpenTag = "<tool_call>" + toolCloseTag = "</tool_call>" +) + +const ( + qwenParserState_LookingForToolStart qwenParserState = iota + qwenParserState_CollectingToolContent +) + +type Qwen3CoderParser struct { + state qwenParserState + acc strings.Builder +} + +func (p *Qwen3CoderParser) HasToolSupport() bool { + return true +} + +func (p *Qwen3CoderParser) HasThinkingSupport() bool { + return false +} + +func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) { + p.acc.WriteString(s) + + events := p.parseEvents() + + var toolCalls []api.ToolCall + var sb strings.Builder + for _, event := range events { + switch event := event.(type) { + case qwenEventRawToolCall: + toolCall, err := parseToolCall(event, tools) + if err != nil { + slog.Warn("qwen tool call parsing failed", "error", err) + return "", "", nil, err + } + toolCalls = append(toolCalls, toolCall) + case qwenEventContent: + // TODO(drifkin): if the same turn contains multiple interleaved content + // events, we naively append them together here. See the note below about + // `qwenEvent`s for more details + sb.WriteString(event.content) + } + } + + return sb.String(), "", toolCalls, nil +} + +func (p *Qwen3CoderParser) parseEvents() []qwenEvent { + var all []qwenEvent + + keepLooping := true + for keepLooping { + var events []qwenEvent + events, keepLooping = eat(p) + if len(events) > 0 { + all = append(all, events...) + } + } + + if len(all) > 0 { + slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String()) + } + + return all +} + +// we use some internal event types in order to communicate between `Add` and +// `eat`. We do this to support interleaving content and parallel tool calls in +// the parser, even though qwen3-coder isn't supposed to do this.
Our API +// doesn't currently support models outputting multiple messages in a turn, so +// we wouldn't be able to represent it yet, but there's no reason to prevent the +// parser from supporting it, especially for future models if they end up using +// a similar format. +type qwenEvent interface { + isQwenEvent() +} + +type qwenEventRawToolCall struct { + raw string +} + +type qwenEventContent struct { + content string +} + +func (qwenEventContent) isQwenEvent() {} +func (qwenEventRawToolCall) isQwenEvent() {} + +// eat consumes the parser's buffer, and returns a list of any unambiguous +// events from the current parser state. If the parser transitions to another +// state, it may have additional events to emit on the next call, which is what +// the second return value indicates +func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) { + var events []qwenEvent + + switch p.state { + case qwenParserState_LookingForToolStart: + if strings.Contains(p.acc.String(), toolOpenTag) { + // we found a full tool open tag, so we can emit the content before the + // tag, being sure to trim any trailing whitespace + split := strings.SplitN(p.acc.String(), toolOpenTag, 2) + before := split[0] + before = strings.TrimRightFunc(before, unicode.IsSpace) + if len(before) > 0 { + events = append(events, qwenEventContent{content: before}) + } + after := split[1] + p.acc.Reset() + p.acc.WriteString(after) + p.state = qwenParserState_CollectingToolContent + return events, true + } else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 { + // we found a partial tool open tag, so we can emit the unambiguous part, + // which is the (trailing-whitespace trimmed) content before the partial + // tool open tag + beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap] + trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen + unambiguous := p.acc.String()[:ambiguousStart] + ambiguous := p.acc.String()[ambiguousStart:] + p.acc.Reset() + p.acc.WriteString(ambiguous) + events = append(events, qwenEventContent{content: unambiguous}) + return events, false + } else { + // we found content that is entirely not a tool call. We should withhold + // any trailing whitespace in case this is the end of the content + whitespaceLen := trailingWhitespaceLen(p.acc.String()) + ambiguousStart := len(p.acc.String()) - whitespaceLen + unambiguous := p.acc.String()[:ambiguousStart] + ambiguous := p.acc.String()[ambiguousStart:] + p.acc.Reset() + p.acc.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, qwenEventContent{content: unambiguous}) + } + return events, false + } + case qwenParserState_CollectingToolContent: + if strings.Contains(p.acc.String(), toolCloseTag) { + split := strings.SplitN(p.acc.String(), toolCloseTag, 2) + before := split[0] + if len(before) == 0 { + slog.Warn("qwen tool call closing tag found but no content before it") + } + // remove any whitespace between the tool call and any content after it + after := strings.TrimLeftFunc(split[1], unicode.IsSpace) + p.acc.Reset() + p.acc.WriteString(after) + events = append(events, qwenEventRawToolCall{raw: before}) + p.state = qwenParserState_LookingForToolStart + return events, true + } else { + // note that we don't need to check the overlap here because we only plan + // on parsing the tool call once we see the full closing tag. 
We don't + // stream back the unparsed tool content, so there's no need to be eager + // here + return events, false + } + default: + panic("unreachable") + } +} + +// TODO(drifkin): move this to a shared location +// longest overlap between suffix of s and prefix of delim +func overlap(s, delim string) int { + max := min(len(delim), len(s)) + for i := max; i > 0; i-- { + if strings.HasSuffix(s, delim[:i]) { + return i + } + } + return 0 +} + +func trailingWhitespaceLen(s string) int { + for i := len(s) - 1; i >= 0; i-- { + if !unicode.IsSpace(rune(s[i])) { + return len(s) - i - 1 + } + } + return len(s) +} + +type XMLFunctionCall struct { + XMLName xml.Name `xml:"function"` + Name string `xml:"name,attr"` + Parameters []XMLParameter `xml:"parameter"` +} + +type XMLParameter struct { + Name string `xml:"name,attr"` + Value string `xml:",chardata"` +} + +// parseToolCall parses a raw tool call string into an api.ToolCall. +// The raw string follows an xml-like format, here's an example: +// +// +// +// San Francisco +// +// +// celsius +// +// +func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) { + toolCall := api.ToolCall{} + + xmlString := transformToXML(raw.raw) + + var functionCall XMLFunctionCall + err := xml.Unmarshal([]byte(xmlString), &functionCall) + if err != nil { + return api.ToolCall{}, err + } + + toolCall.Function = api.ToolCallFunction{ + Name: functionCall.Name, + } + + // Find the matching tool to get parameter types + var matchedTool *api.Tool + for i := range tools { + if tools[i].Function.Name == functionCall.Name { + matchedTool = &tools[i] + break + } + } + + toolCall.Function.Arguments = make(api.ToolCallFunctionArguments) + for _, parameter := range functionCall.Parameters { + // Look up the parameter type if we found the tool + var paramType api.PropertyType + if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil { + if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok { + paramType = prop.Type + } + } + + toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType) + } + + return toolCall, nil +} + +// parseValue converts a raw string value to the appropriate type based on the parameter type specification. +// +// For union types (multiple types in PropertyType, which we support but doesn't +// seem as though the reference parser does type coercion with those types in +// mind) we use a type precedence approach: +// 1. null - checked first regardless of declared types (matches reference implementation) +// 2. boolean - only "true"/"false" are valid booleans +// 3. integer - must parse as a whole number +// 4. number - must parse as numeric (returns int if no decimal part) +// 5. array - must parse as valid JSON array +// 6. object - must parse as valid JSON object +// 7. string - always succeeds (least specific type) +// +// This precedence ensures we return the most specific type that successfully parses, +// following the principle of least surprise. For example, with PropertyType{"string", "number"}, +// "123" becomes 123 (number), while "hello" becomes "hello" (string). +func parseValue(raw string, paramType api.PropertyType) any { + // first remove a single leading newlines, and a single trailing newline (if + // they exist). 
This follows the reference implementation + raw = strings.TrimPrefix(raw, "\n") + raw = strings.TrimSuffix(raw, "\n") + + // Check for null first (case-insensitive) - this takes precedence over any type + if strings.ToLower(raw) == "null" { + return nil + } + + // If no type is specified, default to string + if len(paramType) == 0 { + return raw + } + + // Check if any of the specified types match, using type precedence + // Order: boolean -> integer -> number -> array -> object -> string + typeSet := make(map[string]bool) + for _, t := range paramType { + typeSet[t] = true + } + + // Try boolean first (most restrictive) + if typeSet["boolean"] { + lower := strings.ToLower(raw) + switch lower { + case "true": + return true + case "false": + return false + } + // If not a valid boolean but boolean is the only type, return false (matching reference) + if len(paramType) == 1 { + return false + } + // Otherwise try other types + } + + // Try integer + if typeSet["integer"] { + if i, err := strconv.ParseInt(raw, 10, 64); err == nil { + // Return as int if it fits in int32, otherwise int64 + if i >= math.MinInt32 && i <= math.MaxInt32 { + return int(i) + } + return i + } + // If integer is the only type and parsing failed, fall back to string + if len(paramType) == 1 { + return raw + } + } + + // Try number (float) + if typeSet["number"] { + if f, err := strconv.ParseFloat(raw, 64); err == nil { + // If the number has no decimal part, return as int (matching reference) + if f == math.Trunc(f) { + i := int64(f) + if i >= math.MinInt32 && i <= math.MaxInt32 { + return int(i) + } + return i + } + return f + } + // If number is the only type and parsing failed, fall back to string + if len(paramType) == 1 { + return raw + } + } + + // Try array + if typeSet["array"] { + var arr []interface{} + if err := json.Unmarshal([]byte(raw), &arr); err == nil { + return arr + } + // If array is the only type and parsing failed, fall back to string + if len(paramType) == 1 { + return raw + } + } + + // Try object + if typeSet["object"] { + var obj map[string]interface{} + if err := json.Unmarshal([]byte(raw), &obj); err == nil { + return obj + } + // If object is the only type and parsing failed, fall back to string + if len(paramType) == 1 { + return raw + } + } + + // String always succeeds (or if "string" is in the type set) + if typeSet["string"] { + return raw + } + + // If we get here, none of the types matched and string wasn't an option + // We return string as a fallback. 
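A few worked cases of the trimming and precedence rules above, as a hypothetical helper (the expected results mirror the table-driven cases in TestQwenToolCallValueParsing later in this patch):

func illustrateParseValue() {
	fmt.Println(parseValue("\n42\n", api.PropertyType{"string", "integer"})) // 42: integer outranks string; one newline is trimmed from each end
	fmt.Println(parseValue("NULL", api.PropertyType{"integer"}))             // <nil>: the null check runs before any declared type
	fmt.Println(parseValue("100.0", api.PropertyType{"number"}))             // 100: whole-valued numbers come back as ints
	fmt.Println(parseValue("yes", api.PropertyType{"boolean"}))              // false: a sole boolean type falls back to false
	fmt.Println(parseValue("not-a-number", api.PropertyType{"integer"}))     // not-a-number: a sole non-string type falls back to the raw string
}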
The reference implementation will attempt + // to parse the value as a python literal, but we purposefully don't support + // that + return raw +} + +var qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`) + +// transformToXML transforms a raw qwen tool call with xml-like tags into valid +// xml so that it can be parsed by any xml parser +func transformToXML(raw string) string { + // take the form `` and transform it to ``, taking + // care to properly escape the string that becomes the attribute value + return qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string { + groups := qwenTagRegex.FindStringSubmatch(match) + tag := groups[1] + var escapedValue strings.Builder + xml.EscapeText(&escapedValue, []byte(groups[2])) + return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String()) + }) +} diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go new file mode 100644 index 000000000..c0dad28d1 --- /dev/null +++ b/model/parsers/qwen3coder_test.go @@ -0,0 +1,830 @@ +package parsers + +import ( + "reflect" + "testing" + + "github.com/ollama/ollama/api" +) + +// tool creates a test tool with the given name and properties +func tool(name string, props map[string]api.ToolProperty) api.Tool { + t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}} + t.Function.Parameters.Type = "object" + t.Function.Parameters.Properties = props + return t +} + +func TestQwenParserStreaming(t *testing.T) { + type step struct { + input string + wantEvents []qwenEvent + } + + cases := []struct { + desc string + steps []step + only bool + }{ + { + desc: "simple message streamed word by word", + steps: []step{ + { + input: "hi", + wantEvents: []qwenEvent{qwenEventContent{content: "hi"}}, + }, + { + input: " there", + wantEvents: []qwenEvent{qwenEventContent{content: " there"}}, + }, + }, + }, + { + desc: "content before tool call", + steps: []step{ + { + input: "hi there", + wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}}, + }, + }, + }, + { + desc: "multiple tool calls in one message", + steps: []step{ + { + input: "before1in tool callafter1in tool call 2after2", + wantEvents: []qwenEvent{ + qwenEventContent{content: "before1"}, + qwenEventRawToolCall{raw: "in tool call"}, + qwenEventContent{content: "after1"}, + qwenEventRawToolCall{raw: "in tool call 2"}, + qwenEventContent{content: "after2"}, + }, + }, + }, + }, + { + desc: "tool calls with split tags", + steps: []step{ + { + input: "beforein tool callaf", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "in tool call"}, + qwenEventContent{content: "af"}, + }, + }, + { + input: "ter", + wantEvents: []qwenEvent{ + qwenEventContent{content: "ter"}, + }, + }, + }, + }, + { + desc: "trailing whitespace between content and tool call", + steps: []step{ + { + input: "abc\ndef", + wantEvents: []qwenEvent{ + qwenEventContent{content: "abc"}, + qwenEventRawToolCall{raw: "def"}, + }, + }, + }, + }, + { + desc: "trailing whitespace between tool call and content", + steps: []step{ + { + input: "abc\ndef", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + qwenEventContent{content: "def"}, + }, + }, + }, + }, + { + desc: "empty content before tool call", + steps: []step{ + { + input: "\nabc", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + }, + }, + }, + }, + { + desc: "partial tool open tag fakeout", + steps: []step{ + { + input: "abc\n + +San Francisco + + +celsius + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_temperature", 
+ Arguments: map[string]any{ + "location": "San Francisco", + "unit": "celsius", + }, + }, + }, + }, + { + name: "names with spaces", + tools: []api.Tool{}, + rawToolCall: ` + +San Francisco + + +celsius + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get current temperature", + Arguments: map[string]any{ + "location with spaces": "San Francisco", + "unit with spaces": "celsius", + }, + }, + }, + }, + // this mirrors the reference implementation's behavior, but unclear if it + // ever happens. If so, then we should probably remove them instead, this + // test is to just document the current behavior and test that we don't get + // xml errors + { + name: "names with quotes", + tools: []api.Tool{}, + rawToolCall: ` + +San Francisco + + +"celsius" + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "\"get current temperature\"", + Arguments: map[string]any{ + "\"location with spaces\"": "San Francisco", + "\"unit with spaces\"": "\"celsius\"", + }, + }, + }, + }, + { + name: "tool call with typed parameters", + tools: []api.Tool{ + tool("calculate", map[string]api.ToolProperty{ + "x": {Type: api.PropertyType{"number"}}, + "y": {Type: api.PropertyType{"integer"}}, + "enabled": {Type: api.PropertyType{"boolean"}}, + "items": {Type: api.PropertyType{"array"}}, + }), + }, + rawToolCall: ` + +3.14 + + +42 + + +true + + +["a", "b", "c"] + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "calculate", + Arguments: map[string]any{ + "x": 3.14, + "y": 42, + "enabled": true, + "items": []interface{}{"a", "b", "c"}, + }, + }, + }, + }, + } + + for i, step := range steps { + gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools) + if err != nil { + t.Errorf("step %d (%s): %v", i, step.name, err) + } + if !reflect.DeepEqual(gotToolCall, step.wantToolCall) { + t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall) + } + } +} + +func TestQwenToolCallValueParsing(t *testing.T) { + cases := []struct { + desc string + raw string + paramType api.PropertyType + want any + }{ + { + desc: "default string value (no type specified)", + paramType: api.PropertyType{}, + raw: "some-string", + want: "some-string", + }, + { + desc: "trim a single leading and trailing newline", + paramType: api.PropertyType{}, + raw: "\nsome-string\n", + want: "some-string", + }, + { + desc: "trim at most one leading and trailing newline", + paramType: api.PropertyType{}, + raw: "\n\nsome-string\n\n", + want: "\nsome-string\n", + }, + { + desc: "newline really has to be the first character to be trimmed", + paramType: api.PropertyType{}, + raw: " \nsome-string\n ", + want: " \nsome-string\n ", + }, + { + desc: "numeric type", + paramType: api.PropertyType{"number"}, + raw: "123", + want: 123, + }, + // Integer parsing tests + { + desc: "integer type", + paramType: api.PropertyType{"integer"}, + raw: "42", + want: 42, + }, + { + desc: "negative integer", + paramType: api.PropertyType{"integer"}, + raw: "-100", + want: -100, + }, + { + desc: "zero integer", + paramType: api.PropertyType{"integer"}, + raw: "0", + want: 0, + }, + { + desc: "integer with leading zeros", + paramType: api.PropertyType{"integer"}, + raw: "007", + want: 7, + }, + { + desc: "large integer", + paramType: api.PropertyType{"integer"}, + raw: "2147483648", // Just beyond int32 max + want: int64(2147483648), + }, + // Float/number parsing tests + { + desc: "float type", + paramType: 
api.PropertyType{"number"}, + raw: "3.14", + want: 3.14, + }, + { + desc: "negative float", + paramType: api.PropertyType{"number"}, + raw: "-273.15", + want: -273.15, + }, + { + desc: "float without decimal part", + paramType: api.PropertyType{"number"}, + raw: "100.0", + want: 100, + }, + { + desc: "scientific notation positive", + paramType: api.PropertyType{"number"}, + raw: "1.23e5", + want: 123000, // Will be int since it has no decimal part + }, + { + desc: "scientific notation negative", + paramType: api.PropertyType{"number"}, + raw: "1.5e-3", + want: 0.0015, + }, + { + desc: "very small float", + paramType: api.PropertyType{"number"}, + raw: "0.00000001", + want: 0.00000001, + }, + // String parsing tests + { + desc: "explicit string type", + paramType: api.PropertyType{"string"}, + raw: "hello world", + want: "hello world", + }, + { + desc: "string with special characters", + paramType: api.PropertyType{"string"}, + raw: "/usr/local/bin/test-file_v2.0.sh", + want: "/usr/local/bin/test-file_v2.0.sh", + }, + { + desc: "string with quotes", + paramType: api.PropertyType{"string"}, + raw: `He said "hello" to me`, + want: `He said "hello" to me`, + }, + { + desc: "multiline string", + paramType: api.PropertyType{"string"}, + raw: "line one\nline two\nline three", + want: "line one\nline two\nline three", + }, + { + desc: "empty string", + paramType: api.PropertyType{"string"}, + raw: "", + want: "", + }, + { + desc: "string that looks like a number", + paramType: api.PropertyType{"string"}, + raw: "12345", + want: "12345", + }, + // Boolean parsing tests + { + desc: "boolean true", + paramType: api.PropertyType{"boolean"}, + raw: "true", + want: true, + }, + { + desc: "boolean false", + paramType: api.PropertyType{"boolean"}, + raw: "false", + want: false, + }, + { + desc: "boolean case insensitive true", + paramType: api.PropertyType{"boolean"}, + raw: "True", + want: true, + }, + { + desc: "boolean case insensitive false", + paramType: api.PropertyType{"boolean"}, + raw: "FALSE", + want: false, + }, + // Null parsing tests + { + desc: "null value lowercase", + paramType: api.PropertyType{"string"}, + raw: "null", + want: nil, + }, + { + desc: "null value case insensitive", + paramType: api.PropertyType{"integer"}, + raw: "NULL", + want: nil, + }, + // Array parsing tests + { + desc: "array of strings", + paramType: api.PropertyType{"array"}, + raw: `["foo", "bar", "baz"]`, + want: []interface{}{"foo", "bar", "baz"}, + }, + { + desc: "array of numbers", + paramType: api.PropertyType{"array"}, + raw: `[1, 2.5, 3]`, + want: []interface{}{float64(1), 2.5, float64(3)}, + }, + { + desc: "array of mixed types", + paramType: api.PropertyType{"array"}, + raw: `["string", 123, true, null]`, + want: []interface{}{"string", float64(123), true, nil}, + }, + { + desc: "empty array", + paramType: api.PropertyType{"array"}, + raw: `[]`, + want: []interface{}{}, + }, + // Object parsing tests + { + desc: "simple object", + paramType: api.PropertyType{"object"}, + raw: `{"key": "value", "number": 42}`, + want: map[string]interface{}{"key": "value", "number": float64(42)}, + }, + { + desc: "nested object", + paramType: api.PropertyType{"object"}, + raw: `{"outer": {"inner": "value"}}`, + want: map[string]interface{}{"outer": map[string]interface{}{"inner": "value"}}, + }, + { + desc: "empty object", + paramType: api.PropertyType{"object"}, + raw: `{}`, + want: map[string]interface{}{}, + }, + // Error cases and fallback behavior + { + desc: "invalid integer falls back to string", + paramType: 
api.PropertyType{"integer"}, + raw: "not-a-number", + want: "not-a-number", + }, + { + desc: "invalid float falls back to string", + paramType: api.PropertyType{"number"}, + raw: "3.14.159", + want: "3.14.159", + }, + { + desc: "invalid boolean falls back to false", + paramType: api.PropertyType{"boolean"}, + raw: "yes", + want: false, + }, + { + desc: "invalid JSON array falls back to string", + paramType: api.PropertyType{"array"}, + raw: "[1, 2, unclosed", + want: "[1, 2, unclosed", + }, + { + desc: "invalid JSON object falls back to string", + paramType: api.PropertyType{"object"}, + raw: `{"key": unclosed`, + want: `{"key": unclosed`, + }, + // Edge cases + { + desc: "integer overflow should use int64", + paramType: api.PropertyType{"integer"}, + raw: "2147483648", // Beyond int32 max + want: int64(2147483648), + }, + { + desc: "float with many decimal places", + paramType: api.PropertyType{"number"}, + raw: "3.141592653589793", + want: 3.141592653589793, + }, + { + desc: "string with JSON-like content", + paramType: api.PropertyType{"string"}, + raw: `{"this": "is", "just": "a string"}`, + want: `{"this": "is", "just": "a string"}`, + }, + { + desc: "whitespace-only string", + paramType: api.PropertyType{"string"}, + raw: " ", + want: " ", + }, + // Unknown parameter (no type specified in tools) + { + desc: "parameter not in tool definition defaults to string", + paramType: api.PropertyType{}, + raw: "some value", + want: "some value", + }, + // Union type tests + { + desc: "string or number union - valid number", + paramType: api.PropertyType{"string", "number"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "string or number union - non-numeric string", + paramType: api.PropertyType{"string", "number"}, + raw: "hello", + want: "hello", + }, + { + desc: "number or string union - valid number (order shouldn't matter)", + paramType: api.PropertyType{"number", "string"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "integer or null union - valid integer", + paramType: api.PropertyType{"integer", "null"}, + raw: "123", + want: 123, + }, + { + desc: "integer or null union - null value", + paramType: api.PropertyType{"integer", "null"}, + raw: "null", + want: nil, + }, + { + desc: "null or integer union - null value (order shouldn't matter)", + paramType: api.PropertyType{"null", "integer"}, + raw: "null", + want: nil, + }, + { + desc: "boolean or string union - valid boolean", + paramType: api.PropertyType{"boolean", "string"}, + raw: "true", + want: true, + }, + { + desc: "boolean or string union - non-boolean becomes string", + paramType: api.PropertyType{"boolean", "string"}, + raw: "yes", + want: "yes", + }, + { + desc: "string or boolean union - valid boolean (precedence test)", + paramType: api.PropertyType{"string", "boolean"}, + raw: "false", + want: false, // Should be boolean, not string "false" + }, + { + desc: "integer or number union - integer value", + paramType: api.PropertyType{"integer", "number"}, + raw: "42", + want: 42, + }, + { + desc: "integer or number union - float value", + paramType: api.PropertyType{"integer", "number"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "number or integer union - integer value (precedence test)", + paramType: api.PropertyType{"number", "integer"}, + raw: "42", + want: 42, // Should try integer first due to precedence + }, + { + desc: "array or object union - valid array", + paramType: api.PropertyType{"array", "object"}, + raw: `[1, 2, 3]`, + want: []interface{}{float64(1), float64(2), float64(3)}, + }, + { + desc: "array or 
object union - valid object", + paramType: api.PropertyType{"array", "object"}, + raw: `{"key": "value"}`, + want: map[string]interface{}{"key": "value"}, + }, + { + desc: "object or array union - valid array (precedence test)", + paramType: api.PropertyType{"object", "array"}, + raw: `[1, 2, 3]`, + want: []interface{}{float64(1), float64(2), float64(3)}, + }, + { + desc: "complex multi-type union - null", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "null", + want: nil, + }, + { + desc: "complex multi-type union - boolean", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "true", + want: true, + }, + { + desc: "complex multi-type union - number", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "3.14", + want: 3.14, + }, + { + desc: "complex multi-type union - string", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "hello", + want: "hello", + }, + { + desc: "integer string union - integer string becomes integer", + paramType: api.PropertyType{"integer", "string"}, + raw: "123", + want: 123, + }, + { + desc: "string integer union - integer string becomes integer (precedence)", + paramType: api.PropertyType{"string", "integer"}, + raw: "123", + want: 123, // Integer has higher precedence than string + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got := parseValue(tc.raw, tc.paramType) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want) + } + }) + } +} + +func TestQwenXMLTransform(t *testing.T) { + cases := []struct { + desc string + raw string + want string + }{ + { + desc: "simple example", + raw: ` + +San Francisco + + +celsius + +`, + want: ` + +San Francisco + + +celsius + +`, + }, + // even though quotes aren't expected in these tags, we have these tests to + // make sure they're escaped so they don't blow up the xml parser in case + // they happen + { + desc: "names with quotes", + raw: ` + +San Francisco + + +celsius + +`, + want: ` + +San Francisco + + +celsius + +`, + }, + } + + for _, tc := range cases { + got := transformToXML(tc.raw) + if got != tc.want { + t.Errorf("got %q, want %q", got, tc.want) + } + } +} + +func TestTrailingWhitespaceLen(t *testing.T) { + cases := []struct { + desc string + s string + want int + }{ + {desc: "no whitespace", s: "abc", want: 0}, + {desc: "trailing whitespace", s: "abc ", want: 1}, + {desc: "trailing whitespace with newlines", s: "abc \n", want: 2}, + {desc: "only whitespace", s: " \n ", want: 4}, + {desc: "leading whitespace doesn't count", s: " \n abc", want: 0}, + } + + for _, tc := range cases { + got := trailingWhitespaceLen(tc.s) + if got != tc.want { + t.Errorf("got %d, want %d", got, tc.want) + } + } +} diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go new file mode 100644 index 000000000..074def0eb --- /dev/null +++ b/model/renderers/qwen3coder.go @@ -0,0 +1,217 @@ +package renderers + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/ollama/ollama/api" +) + +var ( + imStartTag = "<|im_start|>" + imEndTag = "<|im_end|>" +) + +// renderAdditionalKeys renders all JSON fields except the ones in handledKeys +// This follows the same approach from the reference implementation, which gives +// a particular key ordering +func renderAdditionalKeys(obj interface{}, handledKeys map[string]bool) string { + data, err := json.Marshal(obj) + if err != nil { + return "" + } + + var m 
map[string]interface{} + if err := json.Unmarshal(data, &m); err != nil { + return "" + } + + var sb strings.Builder + for key, value := range m { + if handledKeys[key] { + continue + } + + // Check if value is a map or array (needs JSON serialization) + switch v := value.(type) { + case map[string]interface{}, []interface{}: + jsonBytes, _ := json.Marshal(v) + // TODO(drifkin): it would be nice to format the JSON here similarly to + // python's default json.dumps behavior (spaces after commas and colons). + // This would let us be byte-for-byte compatible with the reference + // implementation for most common inputs + jsonStr := string(jsonBytes) + sb.WriteString("\n<" + key + ">" + jsonStr + "") + case nil: + continue + default: + // Simple types, convert to string + sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "") + } + } + + return sb.String() +} + +func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) { + var sb strings.Builder + + // filter out system messages and choose the first (if any) to win + var systemMessage string + var filteredMessages []api.Message + for _, message := range messages { + if message.Role != "system" { + filteredMessages = append(filteredMessages, message) + continue + } + + if systemMessage == "" { + systemMessage = message.Content + } + } + + if systemMessage != "" || len(tools) > 0 { + sb.WriteString(imStartTag + "system\n") + + // if we have tools but no system message, match the reference implementation by providing a default system message + if systemMessage == "" { + systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." + } + + sb.WriteString(systemMessage) + + if len(tools) > 0 { + sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n") + sb.WriteString("") + for _, tool := range tools { + sb.WriteString("\n") + sb.WriteString("\n") + sb.WriteString("" + tool.Function.Name + "") + if tool.Function.Description != "" { + sb.WriteString("\n" + tool.Function.Description + "") + } + sb.WriteString("\n") + + for name, prop := range tool.Function.Parameters.Properties { + sb.WriteString("\n") + sb.WriteString("\n" + name + "") + + if len(prop.Type) > 0 { + // TODO(!!!)(drifkin): we should match the reference implementation for + // more complex types here instead of using this format + sb.WriteString("\n" + prop.ToTypeScriptType() + "") + } + + if prop.Description != "" { + sb.WriteString("\n" + prop.Description + "") + } + + // Render any additional keys not already handled + handledKeys := map[string]bool{ + "type": true, + "description": true, + } + sb.WriteString(renderAdditionalKeys(prop, handledKeys)) + + sb.WriteString("\n") + } + + // Render extra keys for parameters (everything except 'type' and 'properties') + paramHandledKeys := map[string]bool{ + "type": true, + "properties": true, + } + sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys)) + + sb.WriteString("\n") + sb.WriteString("\n") + } + sb.WriteString("\n") + sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT 
after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n") + } + + sb.WriteString(imEndTag + "\n") + } + + for i, message := range filteredMessages { + lastMessage := i == len(filteredMessages)-1 + prefill := lastMessage && message.Role == "assistant" + switch message.Role { + case "assistant": + if len(message.ToolCalls) > 0 { + sb.WriteString(imStartTag + "assistant\n") + if message.Content != "" { + sb.WriteString(message.Content + "\n") + } + for _, toolCall := range message.ToolCalls { + sb.WriteString("\n\n") + for name, value := range toolCall.Function.Arguments { + valueStr := formatToolCallArgument(value) + sb.WriteString("\n\n" + valueStr + "\n") + } + sb.WriteString("\n\n") + } + sb.WriteString("<|im_end|>\n") + } else { + sb.WriteString(imStartTag + "assistant\n") + sb.WriteString(message.Content) + if !prefill { + sb.WriteString(imEndTag + "\n") + } + } + case "tool": + // consecutive tool responses should share a single `user`, but + // have their own tags + + // only start a new user block if this is the first tool response + if i == 0 || filteredMessages[i-1].Role != "tool" { + sb.WriteString(imStartTag + "user\n") + } + + sb.WriteString("\n") + sb.WriteString(message.Content) + sb.WriteString("\n\n") + + // close the user block only if this is the last tool response + if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" { + sb.WriteString(imEndTag + "\n") + } + default: + sb.WriteString(imStartTag + message.Role + "\n") + sb.WriteString(message.Content) + sb.WriteString(imEndTag + "\n") + } + + if lastMessage && !prefill { + sb.WriteString(imStartTag + "assistant\n") + } + } + + return sb.String(), nil +} + +func formatToolCallArgument(value any) string { + if value == nil { + return "null" + } + + switch v := value.(type) { + case string: + return v + case []byte: + return string(v) + } + + if reflect.TypeOf(value) != nil { + kind := reflect.TypeOf(value).Kind() + if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array { + if marshalled, err := json.Marshal(value); err == nil { + return string(marshalled) + } + } + } + + return fmt.Sprintf("%v", value) +} diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go new file mode 100644 index 000000000..4aaa066d6 --- /dev/null +++ b/model/renderers/qwen3coder_test.go @@ -0,0 +1,338 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +func TestQwen3CoderRenderer(t *testing.T) { + tests := []struct { + name string + msgs []api.Message + tools []api.Tool + expected string + }{ + { + name: "basic", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello, how are you?"}, + }, + expected: `<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +Hello, how are you?<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "with tools and response", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant with access to tools."}, + {Role: "user", Content: "What is the weather like in San Francisco?"}, + { + Role: "assistant", + Content: "I'll check the weather in San Francisco for you.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{ + "unit": "fahrenheit", + }, + }, + }, + }, + }, + {Role: "tool", Content: "{\"location\": \"San 
Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"}, + {Role: "user", Content: "That sounds nice! What about New York?"}, + }, + tools: []api.Tool{ + {Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: api.ToolFunctionParameters{ + Required: []string{"unit"}, + Properties: map[string]api.ToolProperty{ + "unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"}, + // TODO(drifkin): add multiple params back once we have predictable + // order via some sort of ordered map type (see + // ) + /* + "location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"}, + */ + }, + }, + }}, + }, + expected: `<|im_start|>system +You are a helpful assistant with access to tools. + +# Tools + +You have access to the following functions: + + + +get_weather +Get the current weather in a given location + + +unit +string +The unit of temperature +["celsius","fahrenheit"] + +["unit"] + + + + +If you choose to call a function ONLY reply in the following format with NO suffix: + + + + +value_1 + + +This is the value for the second parameter +that can span +multiple lines + + + + + +Reminder: +- Function calls MUST follow the specified format: an inner block must be nested within XML tags +- Required parameters MUST be specified +- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after +- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls +<|im_end|> +<|im_start|>user +What is the weather like in San Francisco?<|im_end|> +<|im_start|>assistant +I'll check the weather in San Francisco for you. + + + + +fahrenheit + + +<|im_end|> +<|im_start|>user + +{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12} + +<|im_end|> +<|im_start|>user +That sounds nice! What about New York?<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "parallel tool calls", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant with access to tools."}, + {Role: "user", Content: "call double(1) and triple(2)"}, + {Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}}, + {Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}}, + }}, + {Role: "tool", Content: "{\"number\": 2}", ToolName: "double"}, + {Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"}, + }, + tools: []api.Tool{ + {Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{ + "number": {Type: api.PropertyType{"string"}, Description: "The number to double"}, + }}}}, + {Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{ + "number": {Type: api.PropertyType{"string"}, Description: "The number to triple"}, + }}}}, + }, + expected: `<|im_start|>system +You are a helpful assistant with access to tools. 
+ +# Tools + +You have access to the following functions: + + + +double +Double a number + + +number +string +The number to double + + + + +triple +Triple a number + + +number +string +The number to triple + + + + + +If you choose to call a function ONLY reply in the following format with NO suffix: + + + + +value_1 + + +This is the value for the second parameter +that can span +multiple lines + + + + + +Reminder: +- Function calls MUST follow the specified format: an inner block must be nested within XML tags +- Required parameters MUST be specified +- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after +- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls +<|im_end|> +<|im_start|>user +call double(1) and triple(2)<|im_end|> +<|im_start|>assistant +I'll call double(1) and triple(2) for you. + + + + +1 + + + + + + +2 + + +<|im_end|> +<|im_start|>user + +{"number": 2} + + +{"number": 6} + +<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "prefill", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Tell me something interesting."}, + {Role: "assistant", Content: "I'll tell you something interesting about cats"}, + }, + expected: `<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +Tell me something interesting.<|im_end|> +<|im_start|>assistant +I'll tell you something interesting about cats`, + }, + { + name: "complex tool call arguments should remain json encoded", + msgs: []api.Message{ + {Role: "user", Content: "call tool"}, + {Role: "assistant", ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{ + Name: "echo", + Arguments: map[string]any{ + "payload": map[string]any{"foo": "bar"}, + }, + }}, + }}, + {Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"}, + }, + expected: `<|im_start|>user +call tool<|im_end|> +<|im_start|>assistant + + + + +{"foo":"bar"} + + +<|im_end|> +<|im_start|>user + +{"payload": {"foo": "bar"}} + +<|im_end|> +<|im_start|>assistant +`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestFormatToolCallArgument(t *testing.T) { + tests := []struct { + name string + arg any + expected string + }{ + { + name: "string", + arg: "foo", + // notice no quotes around the string + expected: "foo", + }, + { + name: "map", + arg: map[string]any{"foo": "bar"}, + expected: "{\"foo\":\"bar\"}", + }, + { + name: "number", + arg: 1, + expected: "1", + }, + { + name: "boolean", + arg: true, + expected: "true", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatToolCallArgument(tt.arg) + if got != tt.expected { + t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected) + } + }) + } +} diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go new file mode 100644 index 000000000..2dfb51e49 --- /dev/null +++ b/model/renderers/renderer.go @@ -0,0 +1,26 @@ +package renderers + +import ( + "fmt" + + "github.com/ollama/ollama/api" +) + +type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error) + +func RenderWithRenderer(name string, msgs []api.Message, 
tools []api.Tool, think *api.ThinkValue) (string, error) { + renderer := rendererForName(name) + if renderer == nil { + return "", fmt.Errorf("unknown renderer %q", name) + } + return renderer(msgs, tools, think) +} + +func rendererForName(name string) rendererFunc { + switch name { + case "qwen3-coder": + return Qwen3CoderRenderer + default: + return nil + } +} diff --git a/openai/openai.go b/openai/openai.go index b6a8a95e2..7ef5ac6de 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -105,16 +105,18 @@ type ChatCompletionRequest struct { Tools []api.Tool `json:"tools"` Reasoning *Reasoning `json:"reasoning,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"` + DebugRenderOnly bool `json:"_debug_render_only"` } type ChatCompletion struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Choices []Choice `json:"choices"` - Usage Usage `json:"usage,omitempty"` + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage,omitempty"` + DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"` } type ChatCompletionChunk struct { @@ -141,6 +143,7 @@ type CompletionRequest struct { Temperature *float32 `json:"temperature"` TopP float32 `json:"top_p"` Suffix string `json:"suffix"` + DebugRenderOnly bool `json:"_debug_render_only"` } type Completion struct { @@ -273,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { } return nil }(r.DoneReason), - }}, - Usage: toUsage(r), + }}, Usage: toUsage(r), + DebugInfo: r.DebugInfo, } } @@ -568,13 +571,14 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { } return &api.ChatRequest{ - Model: r.Model, - Messages: messages, - Format: format, - Options: options, - Stream: &r.Stream, - Tools: r.Tools, - Think: think, + Model: r.Model, + Messages: messages, + Format: format, + Options: options, + Stream: &r.Stream, + Tools: r.Tools, + Think: think, + DebugRenderOnly: r.DebugRenderOnly, }, nil } @@ -648,11 +652,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { } return api.GenerateRequest{ - Model: r.Model, - Prompt: r.Prompt, - Options: options, - Stream: &r.Stream, - Suffix: r.Suffix, + Model: r.Model, + Prompt: r.Prompt, + Options: options, + Stream: &r.Stream, + Suffix: r.Suffix, + DebugRenderOnly: r.DebugRenderOnly, }, nil } diff --git a/parser/parser.go b/parser/parser.go index e080f1bb7..c2e8f981f 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) req.System = c.Args case "license": licenses = append(licenses, c.Args) + case "renderer": + req.Renderer = c.Args + case "parser": + req.Parser = c.Args case "message": role, msg, _ := strings.Cut(c.Args, ": ") messages = append(messages, api.Message{Role: role, Content: msg}) @@ -320,7 +324,7 @@ func (c Command) String() string { switch c.Name { case "model": fmt.Fprintf(&sb, "FROM %s", c.Args) - case "license", "template", "system", "adapter": + case "license", "template", "system", "adapter", "renderer", "parser": fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) case "message": role, message, _ := strings.Cut(c.Args, ": ") @@ -346,7 +350,7 @@ const ( var ( errMissingFrom = 
errors.New("no FROM line") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") - errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") + errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"") ) type ParserError struct { @@ -606,7 +610,7 @@ func isValidMessageRole(role string) bool { func isValidCommand(cmd string) bool { switch strings.ToLower(cmd) { - case "from", "license", "template", "system", "adapter", "parameter", "message": + case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message": return true default: return false diff --git a/parser/parser_test.go b/parser/parser_test.go index 7d5a808ba..1524e890a 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -198,6 +198,34 @@ BADCOMMAND param1 value1 } } +func TestParseFileRenderer(t *testing.T) { + input := ` +FROM foo +RENDERER renderer1 +` + + reader := strings.NewReader(input) + + modelfile, err := ParseFile(reader) + require.NoError(t, err) + + assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands) +} + +func TestParseFileParser(t *testing.T) { + input := ` +FROM foo +PARSER parser1 +` + + reader := strings.NewReader(input) + + modelfile, err := ParseFile(reader) + require.NoError(t, err) + + assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands) +} + func TestParseFileMessages(t *testing.T) { cases := []struct { input string diff --git a/server/create.go b/server/create.go index bd970876f..f08f18b34 100644 --- a/server/create.go +++ b/server/create.go @@ -323,6 +323,8 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, RootFS: RootFS{ Type: "layers", }, + Renderer: r.Renderer, + Parser: r.Parser, } var layers []Layer diff --git a/server/images.go b/server/images.go index 504eb95cf..6432860f8 100644 --- a/server/images.go +++ b/server/images.go @@ -24,6 +24,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/fs/gguf" + "github.com/ollama/ollama/model/parsers" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/template" "github.com/ollama/ollama/thinking" @@ -94,8 +95,9 @@ func (m *Model) Capabilities() []model.Capability { return capabilities } + builtinParser := parsers.ParserForName(m.Config.Parser) // Check for tools capability - if slices.Contains(m.Template.Vars(), "tools") { + if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) { capabilities = append(capabilities, model.CapabilityTools) } @@ -112,7 +114,8 @@ func (m *Model) Capabilities() []model.Capability { // Check for thinking capability openingTag, closingTag := thinking.InferTags(m.Template.Template) hasTags := openingTag != "" && closingTag != "" - if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) { + isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) + if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) { capabilities = append(capabilities, model.CapabilityThinking) } @@ -198,6 +201,20 @@ func (m *Model) String() string { }) } + if m.Config.Renderer != "" { + modelfile.Commands = 
append(modelfile.Commands, parser.Command{ + Name: "renderer", + Args: m.Config.Renderer, + }) + } + + if m.Config.Parser != "" { + modelfile.Commands = append(modelfile.Commands, parser.Command{ + Name: "parser", + Args: m.Config.Parser, + }) + } + for k, v := range m.Options { switch v := v.(type) { case []any: @@ -238,6 +255,8 @@ type ConfigV2 struct { ModelFamilies []string `json:"model_families"` ModelType string `json:"model_type"` FileType string `json:"file_type"` + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` // required by spec Architecture string `json:"architecture"` diff --git a/server/prompt.go b/server/prompt.go index f1d8020ea..56bc63030 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -11,6 +11,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/model/renderers" "github.com/ollama/ollama/template" ) @@ -41,18 +42,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. } } - thinkVal := false - thinkLevel := "" - if think != nil { - thinkVal = think.Bool() - thinkLevel = think.String() - } - var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { + p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think) + if err != nil { return "", nil, err } - s, err := tokenize(ctx, b.String()) + s, err := tokenize(ctx, p) if err != nil { return "", nil, err } @@ -101,6 +96,23 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. } // truncate any messages that do not fit into the context window + p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think) + if err != nil { + return "", nil, err + } + + return p, images, nil +} + +func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { + if m.Config.Renderer != "" { + rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think) + if err != nil { + return "", err + } + return rendered, nil + } + var b bytes.Buffer thinkVal := false thinkLevel := "" @@ -108,9 +120,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. 
thinkVal = think.Bool() thinkLevel = think.String() } - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { - return "", nil, err + if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { + return "", err } - - return b.String(), images, nil + return b.String(), nil } diff --git a/server/routes.go b/server/routes.go index 5114cb74f..849d2ede8 100644 --- a/server/routes.go +++ b/server/routes.go @@ -35,6 +35,7 @@ import ( "github.com/ollama/ollama/harmony" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/model/parsers" "github.com/ollama/ollama/openai" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" @@ -329,10 +330,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { // If debug mode is enabled, return the rendered template instead of calling the model if req.DebugRenderOnly { - c.JSON(http.StatusOK, api.DebugTemplateResponse{ + c.JSON(http.StatusOK, api.GenerateResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), - DebugInfo: api.DebugInfo{ + DebugInfo: &api.DebugInfo{ RenderedTemplate: prompt, ImageCount: len(images), }, @@ -1617,10 +1618,15 @@ func (s *Server) ChatHandler(c *gin.Context) { } msgs = filterThinkTags(msgs, m) + var builtinParser parsers.BuiltinParser + if m.Config.Parser != "" { + builtinParser = parsers.ParserForName(m.Config.Parser) + } + var harmonyMessageHandler *harmony.HarmonyMessageHandler var harmonyToolParser *harmony.HarmonyToolCallAccumulator - useHarmony := shouldUseHarmony(m) + useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony" processedTools := req.Tools if useHarmony { @@ -1650,10 +1656,10 @@ func (s *Server) ChatHandler(c *gin.Context) { // If debug mode is enabled, return the rendered template instead of calling the model if req.DebugRenderOnly { - c.JSON(http.StatusOK, api.DebugTemplateResponse{ + c.JSON(http.StatusOK, api.ChatResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), - DebugInfo: api.DebugInfo{ + DebugInfo: &api.DebugInfo{ RenderedTemplate: prompt, ImageCount: len(images), }, @@ -1713,6 +1719,7 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } + // TODO(drifkin): fold this as much as possibleinto the generic m.Config.Parser logic if useHarmony { content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser) res.Message.Content = content @@ -1739,6 +1746,27 @@ func (s *Server) ChatHandler(c *gin.Context) { ch <- res } + return + } else if builtinParser != nil { + slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content) + + content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools) + if err != nil { + ch <- gin.H{"error": err.Error()} + return + } + + res.Message.Content = content + res.Message.Thinking = thinking + res.Message.ToolCalls = toolCalls + + if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done { + slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done) + ch <- res + } else { + slog.Log(context.TODO(), logutil.LevelTrace, 
"builtin parser empty output", "parser", m.Config.Parser) + } + return } diff --git a/server/routes_debug_test.go b/server/routes_debug_test.go index f04a1da99..6507284ef 100644 --- a/server/routes_debug_test.go +++ b/server/routes_debug_test.go @@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) { t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String()) } - var response api.DebugTemplateResponse + var response api.GenerateResponse if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { t.Fatalf("failed to unmarshal response: %v", err) } @@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) { t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String()) } - var response api.DebugTemplateResponse + var response api.ChatResponse if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { t.Fatalf("failed to unmarshal response: %v", err) } From 472feec2ff5096eb23f72356f26d67b71f18d01e Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Mon, 15 Sep 2025 11:46:25 -0700 Subject: [PATCH 18/32] address comments --- model/parsers/parsers.go | 4 ++-- model/parsers/qwen3coder_test.go | 22 +++++++++++----------- model/renderers/qwen3coder.go | 6 +++--- server/routes.go | 2 +- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index 001cac442..e6dbd1f4f 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -4,13 +4,13 @@ import ( "github.com/ollama/ollama/api" ) -type BuiltinParser interface { +type Parser interface { Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) HasToolSupport() bool HasThinkingSupport() bool } -func ParserForName(name string) BuiltinParser { +func ParserForName(name string) Parser { switch name { case "qwen3-coder": parser := &Qwen3CoderParser{} diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go index c0dad28d1..2389c77b5 100644 --- a/model/parsers/qwen3coder_test.go +++ b/model/parsers/qwen3coder_test.go @@ -307,7 +307,7 @@ true "x": 3.14, "y": 42, "enabled": true, - "items": []interface{}{"a", "b", "c"}, + "items": []any{"a", "b", "c"}, }, }, }, @@ -510,44 +510,44 @@ func TestQwenToolCallValueParsing(t *testing.T) { desc: "array of strings", paramType: api.PropertyType{"array"}, raw: `["foo", "bar", "baz"]`, - want: []interface{}{"foo", "bar", "baz"}, + want: []any{"foo", "bar", "baz"}, }, { desc: "array of numbers", paramType: api.PropertyType{"array"}, raw: `[1, 2.5, 3]`, - want: []interface{}{float64(1), 2.5, float64(3)}, + want: []any{float64(1), 2.5, float64(3)}, }, { desc: "array of mixed types", paramType: api.PropertyType{"array"}, raw: `["string", 123, true, null]`, - want: []interface{}{"string", float64(123), true, nil}, + want: []any{"string", float64(123), true, nil}, }, { desc: "empty array", paramType: api.PropertyType{"array"}, raw: `[]`, - want: []interface{}{}, + want: []any{}, }, // Object parsing tests { desc: "simple object", paramType: api.PropertyType{"object"}, raw: `{"key": "value", "number": 42}`, - want: map[string]interface{}{"key": "value", "number": float64(42)}, + want: map[string]any{"key": "value", "number": float64(42)}, }, { desc: "nested object", paramType: api.PropertyType{"object"}, raw: `{"outer": {"inner": "value"}}`, - want: map[string]interface{}{"outer": map[string]interface{}{"inner": "value"}}, + want: map[string]any{"outer": map[string]any{"inner": "value"}}, }, { desc: "empty 
object", paramType: api.PropertyType{"object"}, raw: `{}`, - want: map[string]interface{}{}, + want: map[string]any{}, }, // Error cases and fallback behavior { @@ -689,19 +689,19 @@ func TestQwenToolCallValueParsing(t *testing.T) { desc: "array or object union - valid array", paramType: api.PropertyType{"array", "object"}, raw: `[1, 2, 3]`, - want: []interface{}{float64(1), float64(2), float64(3)}, + want: []any{float64(1), float64(2), float64(3)}, }, { desc: "array or object union - valid object", paramType: api.PropertyType{"array", "object"}, raw: `{"key": "value"}`, - want: map[string]interface{}{"key": "value"}, + want: map[string]any{"key": "value"}, }, { desc: "object or array union - valid array (precedence test)", paramType: api.PropertyType{"object", "array"}, raw: `[1, 2, 3]`, - want: []interface{}{float64(1), float64(2), float64(3)}, + want: []any{float64(1), float64(2), float64(3)}, }, { desc: "complex multi-type union - null", diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go index 074def0eb..df3b3a45b 100644 --- a/model/renderers/qwen3coder.go +++ b/model/renderers/qwen3coder.go @@ -17,13 +17,13 @@ var ( // renderAdditionalKeys renders all JSON fields except the ones in handledKeys // This follows the same approach from the reference implementation, which gives // a particular key ordering -func renderAdditionalKeys(obj interface{}, handledKeys map[string]bool) string { +func renderAdditionalKeys(obj any, handledKeys map[string]bool) string { data, err := json.Marshal(obj) if err != nil { return "" } - var m map[string]interface{} + var m map[string]any if err := json.Unmarshal(data, &m); err != nil { return "" } @@ -36,7 +36,7 @@ func renderAdditionalKeys(obj interface{}, handledKeys map[string]bool) string { // Check if value is a map or array (needs JSON serialization) switch v := value.(type) { - case map[string]interface{}, []interface{}: + case map[string]any, []any: jsonBytes, _ := json.Marshal(v) // TODO(drifkin): it would be nice to format the JSON here similarly to // python's default json.dumps behavior (spaces after commas and colons). 
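To make the additional-key rendering concrete, a rough sketch (illustration only, not part of this diff) of what renderAdditionalKeys contributes for the enum property used in the renderer tests earlier in this series, assuming the property's extra field marshals under the JSON key "enum":

func illustrateAdditionalKeys() string {
	prop := api.ToolProperty{
		Type:        api.PropertyType{"string"},
		Description: "The unit of temperature",
		Enum:        []any{"celsius", "fahrenheit"},
	}
	// "type" and "description" are rendered explicitly by the caller, so they
	// are passed in as handled keys; each remaining key becomes a
	// <key>value</key> line, with maps and slices JSON-encoded.
	return renderAdditionalKeys(prop, map[string]bool{"type": true, "description": true})
	// => "\n<enum>[\"celsius\",\"fahrenheit\"]</enum>"
}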
diff --git a/server/routes.go b/server/routes.go index 849d2ede8..a4fa00475 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1618,7 +1618,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } msgs = filterThinkTags(msgs, m) - var builtinParser parsers.BuiltinParser + var builtinParser parsers.Parser if m.Config.Parser != "" { builtinParser = parsers.ParserForName(m.Config.Parser) } From 6f7117145f56e00d16572ed11cb8f83c4b3af636 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 15 Sep 2025 14:33:06 -0700 Subject: [PATCH 19/32] batch: use tensors for outputs (#12185) this cleans up the model interface slightly without too much impact in other areas --- model/input/input.go | 14 +++++++------- model/models/gemma2/model.go | 3 +-- model/models/gemma3/embed.go | 1 - model/models/gemma3/model_text.go | 3 +-- model/models/gemma3n/model_text.go | 2 +- model/models/gptoss/model.go | 4 ++-- model/models/llama/model.go | 2 +- model/models/llama4/model.go | 4 +--- model/models/mistral3/model.go | 3 +-- model/models/mllama/model.go | 3 +-- model/models/qwen2/model.go | 2 +- model/models/qwen25vl/model.go | 3 +-- model/models/qwen3/model.go | 2 +- runner/ollamarunner/runner.go | 18 ++++++++---------- 14 files changed, 27 insertions(+), 37 deletions(-) diff --git a/model/input/input.go b/model/input/input.go index bd9b53ec6..35dc41b35 100644 --- a/model/input/input.go +++ b/model/input/input.go @@ -54,10 +54,9 @@ type Batch struct { // Inputs is the input tokens, including placeholders for multimodal inputs. Inputs ml.Tensor - // Multimodal is a set of multimodal embeddings previously created by - // EncodeMultimodal, along with an index into Inputs. Unused for text-only - // models or for batches without multimodal elements. - Multimodal []MultimodalIndex + // Outputs are the set of indicies into Inputs for which output data should + // be returned. + Outputs ml.Tensor // Positions is the position for each Input, relative to its sequence. Equal // in length to Inputs. @@ -66,7 +65,8 @@ type Batch struct { // Sequences is the sequence for each Input. Equal in length to Inputs. Sequences []int - // Outputs are the set of indicies into Inputs for which output data should - // be returned. - Outputs []int32 + // Multimodal is a set of multimodal embeddings previously created by + // EncodeMultimodal, along with an index into Inputs. Unused for text-only + // models or for batches without multimodal elements. 
+ Multimodal []MultimodalIndex } diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index e621d03ae..84c89e1fe 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) @@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var lastLayerOutputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + lastLayerOutputs = batch.Outputs } hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go index 16c299e22..7d1e269ff 100644 --- a/model/models/gemma3/embed.go +++ b/model/models/gemma3/embed.go @@ -22,7 +22,6 @@ type embedModel struct { } func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - batch.Outputs = batch.Positions // return all positions hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) switch m.PoolingType { diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 2a3b23939..5e515a927 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) @@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac var lastLayerOutputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + lastLayerOutputs = batch.Outputs } hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig) diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index b75a2abb3..eeb9ab028 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx) hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx) - hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))) + hiddenStates = hiddenStates.Rows(ctx, batch.Outputs) hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) return m.Output.Forward(ctx, hiddenStates), nil diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index 3ef078095..a74f76487 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err } var outputs ml.Tensor - if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 { - outputs = 
ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if i == len(m.TransformerBlocks)-1 { + outputs = batch.Outputs } hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 77d8f36d3..51273a014 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + outputs = batch.Outputs } hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index 99a898d2d..9cb2efc87 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 408e54d3d..435b1a304 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index d0ad4670e..239d999d5 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) // TODO: attention mask, cross attention mask - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil } func init() { diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 3c662f068..93a502612 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + outputs = batch.Outputs } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index d73f499d2..6c76305db 100644 --- 
a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache) } func init() { diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 7a83e0d04..ec2adaa7d 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -165,7 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + outputs = batch.Outputs } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 1081a1f55..3a32384f8 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -467,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er // Prepare the seqs and batch, but defer the input token values as we may not be ready yet var batchInputs []*input.Input + var batchOutputs []int32 var batch input.Batch resumeSeq := -1 @@ -549,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) batch.Sequences = append(batch.Sequences, seq.cache.Id) - seq.iBatch = len(batch.Outputs) - if i+1 == len(seq.inputs) { - batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) + seq.iBatch = len(batchOutputs) + if i+1 == len(seq.inputs) || seq.embeddingOnly { + batchOutputs = append(batchOutputs, int32(len(batchInputs)-1)) } logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs)) seq.pendingInputs = append(seq.pendingInputs, inp) @@ -576,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs)) + batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs)) nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch) if err != nil { err = fmt.Errorf("failed to build graph: %w", err) @@ -703,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) { } // sample a token - vocabSize := len(outputs) / len(activeBatch.batch.Outputs) - logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches) + vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0) + logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches) token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : 
(iBatches[i]+1)*vocabSize]) if err != nil { s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err) @@ -1046,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error { batch.Positions[i] = int32(i) } - batch.Outputs = make([]int32, s.parallel) - for i := range batch.Outputs { - batch.Outputs[i] = int32(i) - } - batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) + batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel) cache := s.model.Config().Cache if cache != nil { From 3f6642f6fcf971ed6da02aac30ecb68168556482 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 15 Sep 2025 15:35:59 -0700 Subject: [PATCH 20/32] model: implement bert in ollama engine (#9080) * fix truncate * s/SentencePieceModel/SentencePiece/ * bert * wordpiece * refactor pooling * more tokenizers * normalize embeddings --- convert/convert_bert.go | 8 +- ml/backend.go | 1 + ml/backend/ggml/ggml.go | 7 ++ ml/nn/pooling/pooling.go | 36 +++++++ model/model.go | 8 +- model/models/bert/model.go | 181 ++++++++++++++++++++++++++++++++++ model/models/gemma2/model.go | 4 +- model/models/gemma3/embed.go | 24 ++--- model/models/gemma3/model.go | 4 +- model/models/gemma3n/model.go | 4 +- model/models/models.go | 1 + model/sentencepiece.go | 16 +-- model/sentencepiece_test.go | 8 +- model/wordpiece.go | 167 +++++++++++++++++++++++++++++++ model/wordpiece_test.go | 51 ++++++++++ server/routes.go | 10 +- 16 files changed, 490 insertions(+), 40 deletions(-) create mode 100644 ml/nn/pooling/pooling.go create mode 100644 model/models/bert/model.go create mode 100644 model/wordpiece.go create mode 100644 model/wordpiece_test.go diff --git a/convert/convert_bert.go b/convert/convert_bert.go index a9f4b8a77..6b0d0030a 100644 --- a/convert/convert_bert.go +++ b/convert/convert_bert.go @@ -28,6 +28,7 @@ type bertModel struct { LayerNormEPS float32 `json:"layer_norm_eps"` LayerNormEpsilon float32 `json:"layer_norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"` + normalizeEmbeddings bool PoolingType uint32 } @@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error { var pooling string for _, m := range modules { - if m.Type == "sentence_transformers.models.Pooling" { + switch m.Type { + case "sentence_transformers.models.Pooling": pooling = m.Path - break + case "sentence_transformers.models.Normalize": + p.normalizeEmbeddings = true } } @@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV { kv["general.architecture"] = "bert" kv["bert.attention.causal"] = false kv["bert.pooling_type"] = p.PoolingType + kv["bert.normalize_embeddings"] = p.normalizeEmbeddings kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer) diff --git a/ml/backend.go b/ml/backend.go index 154a0f1b5..ef7564782 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -416,6 +416,7 @@ type Tensor interface { AddID(ctx Context, t2, ids Tensor) Tensor Softmax(ctx Context) Tensor + L2Norm(ctx Context, eps float32) Tensor LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor RMSNorm(ctx Context, weight Tensor, eps float32) Tensor Scale(ctx Context, s float64) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 931386d56..d5e2e9c9c 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor { } } +func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)), + } +} + func (t *Tensor) 
LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)) if w != nil { diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go new file mode 100644 index 000000000..f84690c41 --- /dev/null +++ b/ml/nn/pooling/pooling.go @@ -0,0 +1,36 @@ +package pooling + +import ( + "github.com/ollama/ollama/ml" +) + +type Type uint32 + +const ( + TypeNone Type = iota + TypeMean + TypeCLS + TypeLast + TypeRank + + TypeUnknown = 0xFFFFFFFE + TypeUnspecified = 0xFFFFFFFF +) + +func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor { + switch poolingType { + case TypeNone: + return hiddenStates + case TypeMean: + hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) + return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + case TypeCLS: + return hiddenStates.View(ctx, 0, hiddenStates.Dim(0)) + case TypeLast: + panic("not implemented") + case TypeRank: + panic("not implemented") + default: + panic("not implemented") + } +} diff --git a/model/model.go b/model/model.go index 3a72f09aa..efef71d8b 100644 --- a/model/model.go +++ b/model/model.go @@ -24,7 +24,11 @@ import ( "github.com/ollama/ollama/model/input" ) -var ErrNoVisionModel = errors.New("this model is missing data required for image input") +var ( + ErrNoVisionModel = errors.New("this model is missing data required for image input") + ErrUnsupportedModel = errors.New("model not supported") + ErrUnsupportedTokenizer = errors.New("tokenizer not supported") +) // Model implements a specific model architecture, defining the forward pass and any model-specific configuration type Model interface { @@ -242,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) { vv = vv.Elem() } - vv = vv.Elem() + vv = reflect.Indirect(vv) if v.IsNil() { vv = reflect.New(v.Type().Elem()).Elem() } diff --git a/model/models/bert/model.go b/model/models/bert/model.go new file mode 100644 index 000000000..fd1dbd773 --- /dev/null +++ b/model/models/bert/model.go @@ -0,0 +1,181 @@ +package bert + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/pooling" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Model struct { + model.Base + model.TextProcessor + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + TypeEmbedding *nn.Embedding `gguf:"token_types"` + PositionEmbedding *nn.Embedding `gguf:"position_embd"` + TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"` + + Layers []EncoderLayer `gguf:"blk"` + + Options +} + +// Forward implements model.Model. 
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize)) + hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)))) + hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps) + + for _, layer := range m.Layers { + hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options) + } + + hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType) + if m.normalize { + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) + } + + return hiddenStates, nil +} + +type EncoderLayer struct { + *Attention + AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"` + + *MLP + MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"` +} + +func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + // Attention + residual := hiddenStates + hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + + // MLP + residual = hiddenStates + hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + + return hiddenStates +} + +type Attention struct { + Query *nn.Linear `gguf:"attn_q"` + QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"` + + Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"` + + Value *nn.Linear `gguf:"attn_v"` + + Output *nn.Linear `gguf:"attn_output"` +} + +func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + + query := a.Query.Forward(ctx, hiddenStates) + if a.QueryNorm != nil { + query = a.QueryNorm.Forward(ctx, query, opts.eps) + } + query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) + + key := a.Key.Forward(ctx, hiddenStates) + if a.KeyNorm != nil { + key = a.KeyNorm.Forward(ctx, key, opts.eps) + } + key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize) + + value := a.Value.Forward(ctx, hiddenStates) + value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize) + + attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) + return a.Output.Forward(ctx, attention) +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx)) +} + +type Options struct { + hiddenSize, + numHeads, + numKVHeads, + keyLength, + valueLength int + poolingType pooling.Type + eps float32 + normalize bool +} + +func (o Options) headDim() int { + return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) +} + +func New(c fs.Config) (model.Model, error) { + var processor model.TextProcessor + switch c.String("tokenizer.ggml.model", "bert") { + case "bert": + processor = model.NewWordPiece( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{ + 
int32(cmp.Or( + c.Uint("tokenizer.ggml.cls_token_id"), + c.Uint("tokenizer.ggml.bos_token_id"), + )), + }, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true), + EOS: []int32{ + int32(cmp.Or( + c.Uint("tokenizer.ggml.separator_token_id"), + //nolint:misspell + // NOTE: "seperator_token_id" is a typo in model metadata but we need to + // support it for compatibility. + c.Uint("tokenizer.ggml.seperator_token_id"), + c.Uint("tokenizer.ggml.eos_token_id"), + )), + }, + }, + ) + default: + return nil, model.ErrUnsupportedTokenizer + } + + return &Model{ + TextProcessor: processor, + Layers: make([]EncoderLayer, c.Uint("block_count")), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + eps: c.Float("attention.layer_norm_epsilon"), + poolingType: pooling.Type(c.Uint("pooling_type")), + normalize: c.Bool("normalize_embeddings", true), + }, + }, nil +} + +func init() { + model.Register("bert", New) + model.Register("bert_embed", New) +} diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 84c89e1fe..8ccb9f92a 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -24,7 +24,7 @@ type Options struct { type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` @@ -40,7 +40,7 @@ const ( func New(c fs.Config) (model.Model, error) { m := Model{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go index 7d1e269ff..395a0344d 100644 --- a/model/models/gemma3/embed.go +++ b/model/models/gemma3/embed.go @@ -1,48 +1,38 @@ package gemma3 import ( - "errors" - "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" ) type embedModel struct { model.Base - model.SentencePieceModel + model.SentencePiece *TextModel - PoolingType uint32 + poolingType pooling.Type Dense [2]*nn.Linear `gguf:"dense"` } func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) - - switch m.PoolingType { - case 0: // None - case 1: // Mean - hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) - hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) - default: - return nil, errors.New("unsupported pooling type") - } - + hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType) for _, dense := range m.Dense { hiddenStates = dense.Forward(ctx, hiddenStates) } - + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) return hiddenStates, nil } func newEmbedModel(c fs.Config) (model.Model, error) { m := &embedModel{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), @@ -60,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) { }, ), TextModel: newTextModel(c), - PoolingType: c.Uint("pooling_type", 0), + poolingType: pooling.Type(c.Uint("pooling_type", 0)), } m.Cache = kvcache.NewWrapperCache( diff 
--git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 5c92b6bf9..27da889e4 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -16,7 +16,7 @@ import ( type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece *VisionModel `gguf:"v"` *TextModel @@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i func New(c fs.Config) (model.Model, error) { m := Model{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3n/model.go b/model/models/gemma3n/model.go index 6e83a9724..e59e3193f 100644 --- a/model/models/gemma3n/model.go +++ b/model/models/gemma3n/model.go @@ -10,7 +10,7 @@ import ( type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece *TextModel } @@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func New(c fs.Config) (model.Model, error) { m := Model{ TextModel: newTextModel(c), - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/models.go b/model/models/models.go index c880a4720..cc9980789 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -1,6 +1,7 @@ package models import ( + _ "github.com/ollama/ollama/model/models/bert" _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3n" diff --git a/model/sentencepiece.go b/model/sentencepiece.go index 827ce00d9..db07beee9 100644 --- a/model/sentencepiece.go +++ b/model/sentencepiece.go @@ -12,18 +12,18 @@ import ( const spmWhitespaceSep = "▁" -type SentencePieceModel struct { +type SentencePiece struct { maxTokenLen int vocab *Vocabulary } -var _ TextProcessor = (*SentencePieceModel)(nil) +var _ TextProcessor = (*SentencePiece)(nil) -func (spm SentencePieceModel) Vocabulary() *Vocabulary { +func (spm SentencePiece) Vocabulary() *Vocabulary { return spm.vocab } -func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { +func NewSentencePiece(vocab *Vocabulary) SentencePiece { logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) counter := map[int]int{} @@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], "max token len", maxTokenLen) - return SentencePieceModel{ + return SentencePiece{ maxTokenLen: maxTokenLen, vocab: vocab, } } -func (spm SentencePieceModel) Is(id int32, special Special) bool { +func (spm SentencePiece) Is(id int32, special Special) bool { return spm.vocab.Is(id, special) } -func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) { +func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) { fragments := []fragment{{value: s}} for _, special := range spm.vocab.SpecialVocabulary() { id := spm.vocab.Encode(special) @@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} { return item } -func (spm SentencePieceModel) Decode(ids []int32) (string, error) { +func (spm SentencePiece) Decode(ids []int32) (string, error) { 
var sb strings.Builder for _, id := range ids { data := spm.vocab.Decode(id) diff --git a/model/sentencepiece_test.go b/model/sentencepiece_test.go index 50ac26787..8f4570c17 100644 --- a/model/sentencepiece_test.go +++ b/model/sentencepiece_test.go @@ -12,7 +12,7 @@ import ( "github.com/ollama/ollama/convert/sentencepiece" ) -func loadSentencePieceVocab(t *testing.T) SentencePieceModel { +func loadSentencePieceVocab(t *testing.T) SentencePiece { t.Helper() bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model")) @@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel { } } - return NewSentencePieceModel(&v) + return NewSentencePiece(&v) } func TestSentencePieceEncode(t *testing.T) { @@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) { }) } -func TestSentencePieceModelDecodeByteTokens(t *testing.T) { +func TestSentencePieceDecodeByteTokens(t *testing.T) { vocab := &Vocabulary{ Values: []string{ "normal", @@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) { Scores: []float32{0, 0, 0, 0, 0}, } - spm := NewSentencePieceModel(vocab) + spm := NewSentencePiece(vocab) tests := []struct { name string diff --git a/model/wordpiece.go b/model/wordpiece.go new file mode 100644 index 000000000..e8d5e848a --- /dev/null +++ b/model/wordpiece.go @@ -0,0 +1,167 @@ +package model + +import ( + "fmt" + "iter" + "strings" + "unicode" + + "github.com/ollama/ollama/logutil" +) + +type WordPiece struct { + vocab *Vocabulary +} + +// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries. +// this differs from original word piece which uses "##" to indicate subwords. +const ggmlPrefix = "▁" + +var wordPieceReplacer = strings.NewReplacer( + " .", ".", + " ?", "?", + " !", "!", + " ,", ",", + " ' ", "'", + " n't", "n't", + " 'm", "'m", + " do not", " don't", + " 's", "'s", + " 've", "'ve", + " 're", "'re", +) + +// Decode implements TextProcessor. +func (wpm WordPiece) Decode(ids []int32) (string, error) { + var sb strings.Builder + for i, id := range ids { + if id < 0 || int(id) >= len(wpm.vocab.Values) { + return "", fmt.Errorf("invalid token id: %d", id) + } + + var separator string + piece := wpm.vocab.Values[id] + if i > 0 && + (strings.HasPrefix(piece, ggmlPrefix) || + (strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) { + separator = " " + } + + sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix))) + } + + return sb.String(), nil +} + +// words splits a string into words, treating CJK characters as separate words. +// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models. 
+func (wpm WordPiece) words(s string) iter.Seq[string] { + return func(yield func(string) bool) { + runes := make([]rune, 0, len(s)*3) + for _, r := range s { + switch { + case r >= 0x4E00 && r <= 0x9FFF, + r >= 0x3400 && r <= 0x4DBF, + r >= 0x20000 && r <= 0x2A6DF, + r >= 0x2A700 && r <= 0x2B73F, + r >= 0x2B740 && r <= 0x2B81F, + r >= 0x2B820 && r <= 0x2CEAF, + r >= 0xF900 && r <= 0xFAFF, + r >= 0x2F800 && r <= 0x2FA1F: + runes = append(runes, ' ', r, ' ') + default: + runes = append(runes, r) + } + } + + for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) { + // split on but keep punctuation + var start int + for start < len(w) { + end := strings.IndexFunc(w[start:], unicode.IsPunct) + if end < 0 { + end = len(w) - start + } else if end == 0 { + end = 1 + } + + if !yield(w[start : start+end]) { + return + } + + start += end + } + } + } +} + +// Encode implements TextProcessor. +func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) { + var ids []int32 + + // TODO: use [UNK] from config + unk := wpm.vocab.Encode("[UNK]") + for word := range wpm.words(s) { + var start int + var pieces []int32 + for start < len(word) { + end := len(word) + + var piece int32 + for start < end { + subword := word[start:end] + if start == 0 { + subword = ggmlPrefix + subword + } + + // TODO: some models might not want [ToLower] + piece = wpm.vocab.Encode(strings.ToLower(subword)) + if piece >= 0 { + break + } + + end-- + } + + if piece < 0 { + // Unknown token + pieces = pieces[:0] + break + } + + pieces = append(pieces, piece) + start = end + } + + if len(pieces) > 0 { + ids = append(ids, pieces...) + } else { + ids = append(ids, unk) + } + } + + if addSpecial && len(ids) > 0 { + ids = wpm.vocab.addSpecials(ids) + } + + logutil.Trace("encoded", "string", s, "ids", ids) + return ids, nil +} + +// Is implements TextProcessor. +func (wpm WordPiece) Is(id int32, special Special) bool { + return wpm.vocab.Is(id, special) +} + +// Vocabulary implements TextProcessor. +func (wpm WordPiece) Vocabulary() *Vocabulary { + return wpm.vocab +} + +var _ TextProcessor = (*WordPiece)(nil) + +func NewWordPiece(vocab *Vocabulary) WordPiece { + return WordPiece{ + vocab: vocab, + } +} diff --git a/model/wordpiece_test.go b/model/wordpiece_test.go new file mode 100644 index 000000000..258fbffcb --- /dev/null +++ b/model/wordpiece_test.go @@ -0,0 +1,51 @@ +package model + +import ( + "slices" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestWordPiece(t *testing.T) { + wpm := NewWordPiece( + &Vocabulary{ + Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"}, + AddBOS: true, + AddEOS: true, + BOS: []int32{1}, + EOS: []int32{2}, + }) + + ids, err := wpm.Encode("Hello world!", true) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" { + t.Errorf("unexpected ids (-want +got):\n%s", diff) + } + + words, err := wpm.Decode(ids) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } +} + +func TestWordPieceWords(t *testing.T) { + var wpm WordPiece + + basic := slices.Collect(wpm.words("Hey friend! 
How are you?!?")) + if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } + + chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika")) + if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } +} diff --git a/server/routes.go b/server/routes.go index 5114cb74f..739ce69da 100644 --- a/server/routes.go +++ b/server/routes.go @@ -488,7 +488,6 @@ func (s *Server) EmbedHandler(c *gin.Context) { } truncate := true - if req.Truncate != nil && !*req.Truncate { truncate = false } @@ -555,7 +554,16 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } + if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) { + ctxLen-- + } + + if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) { + ctxLen-- + } + tokens = tokens[:ctxLen] + s, err = r.Detokenize(c.Request.Context(), tokens) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) From 93c64ea1b1941bc6077ff1cf48c8eac902052f77 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 15 Sep 2025 15:45:35 -0700 Subject: [PATCH 21/32] doc: show how to clear the cgo cache (#12298) --- docs/development.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/development.md b/docs/development.md index 9726b5d91..ff07b5fb6 100644 --- a/docs/development.md +++ b/docs/development.md @@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository: go run . serve ``` +> [!NOTE] +> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first. + + ## macOS (Apple Silicon) macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required. From a1cff89b30b6f0343ed79051b4a4868dc76bc2ad Mon Sep 17 00:00:00 2001 From: Beshoy Girgis <323600+egyptianbman@users.noreply.github.com> Date: Tue, 16 Sep 2025 09:47:06 -0500 Subject: [PATCH 22/32] fix: fix CUDA detection for older GPUs (#12300) Prioritize GPU compute capability over driver version to ensure Pascal GPUs (CC 6.1) use compatible CUDA v12 libraries instead of v13. 
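For reference, the selection order this change establishes can be summarized in a short Go sketch (an illustration only, not the patched function itself; it reuses the computeMajor/computeMinor/DriverMajor fields visible in the diff below, and the v13 fallback for newer drivers is inferred from this message rather than shown in the hunk):

    func cudaVariantSketch(computeMajor, computeMinor, driverMajor int) string {
        // Pascal and older GPUs (compute capability below 7.5) must stay on the
        // CUDA v12 libraries no matter how new the installed driver is.
        if computeMajor < 7 || (computeMajor == 7 && computeMinor < 5) {
            return "v12"
        }
        // Turing and newer GPUs: choose by driver version. Drivers older than the
        // CUDA 13 series still get v12 (the real code also logs an upgrade warning).
        if driverMajor < 13 {
            return "v12"
        }
        // Newer drivers can use the v13 libraries.
        return "v13"
    }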
--- discover/cuda_common.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/discover/cuda_common.go b/discover/cuda_common.go index ca008af63..3c7a92114 100644 --- a/discover/cuda_common.go +++ b/discover/cuda_common.go @@ -45,10 +45,18 @@ func cudaVariant(gpuInfo CudaGPUInfo) string { } } + // Check GPU compute capability FIRST + isOldGPU := gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) + if isOldGPU { + // GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1) + return "v12" + } + + // GPU is Turing or newer (CC >= 7.5) - can use newer CUDA if gpuInfo.DriverMajor < 13 { // The detected driver is older than 580 (Aug 2025) // Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance - if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) { + if !isOldGPU { slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor)) } return "v12" From c253433d68c9337101ec089d55e9f32fb8921a45 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 16 Sep 2025 09:48:42 -0700 Subject: [PATCH 23/32] embed: cleanup (#12299) * cleanup * use pooling.TypeNone * pooling test --- ml/nn/pooling/pooling.go | 30 +++++---- ml/nn/pooling/pooling_test.go | 79 ++++++++++++++++++++++++ model/model.go | 4 +- model/models/bert/{model.go => embed.go} | 2 +- model/models/gemma3/embed.go | 2 +- runner/ollamarunner/runner.go | 6 +- 6 files changed, 104 insertions(+), 19 deletions(-) create mode 100644 ml/nn/pooling/pooling_test.go rename model/models/bert/{model.go => embed.go} (98%) diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go index f84690c41..63b63b3af 100644 --- a/ml/nn/pooling/pooling.go +++ b/ml/nn/pooling/pooling.go @@ -11,26 +11,32 @@ const ( TypeMean TypeCLS TypeLast - TypeRank - - TypeUnknown = 0xFFFFFFFE - TypeUnspecified = 0xFFFFFFFF ) -func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor { - switch poolingType { - case TypeNone: - return hiddenStates +func (t Type) String() string { + switch t { + case TypeMean: + return "Mean" + case TypeCLS: + return "CLS" + case TypeLast: + return "Last" + default: + return "Unknown" + } +} + +func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { + switch t { case TypeMean: hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) case TypeCLS: return hiddenStates.View(ctx, 0, hiddenStates.Dim(0)) case TypeLast: - panic("not implemented") - case TypeRank: - panic("not implemented") + hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0)) + return hiddenStates default: - panic("not implemented") + panic("unknown pooling type") } } diff --git a/ml/nn/pooling/pooling_test.go b/ml/nn/pooling/pooling_test.go new file mode 100644 index 000000000..c80019459 --- /dev/null +++ b/ml/nn/pooling/pooling_test.go @@ -0,0 +1,79 @@ +package pooling_test + +import ( + "bytes" + "os" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/discover" + fsggml "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/backend/ggml" + "github.com/ollama/ollama/ml/nn/pooling" +) + +func setup(tb testing.TB, n int) ml.Backend { + tb.Helper() + + f, err := 
os.CreateTemp(tb.TempDir(), "*.bin") + if err != nil { + tb.Fatal(err) + } + defer f.Close() + + if err := fsggml.WriteGGUF(f, fsggml.KV{ + "general.architecture": "test", + "test.block_count": uint32(1), + }, []*fsggml.Tensor{ + {Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))}, + }); err != nil { + tb.Fatal(err) + } + + var gpuLayers ml.GPULayersList + if gpus := discover.GetGPUInfo(); len(gpus) > 0 { + gpuLayers = append(gpuLayers, ml.GPULayers{ + ID: gpus[0].ID, + Layers: slices.Collect(func(yield func(int) bool) { + for i := range n { + if !yield(i) { + return + } + } + }), + }) + } + b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers}) + if err != nil { + tb.Fatal(err) + } + + return b +} + +func TestForward(t *testing.T) { + cases := map[pooling.Type][]float32{ + pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11}, + pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7}, + pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15}, + } + for typ, want := range cases { + t.Run(typ.String(), func(t *testing.T) { + b := setup(t, 99) + defer b.Close() + + ctx := b.NewContext() + defer ctx.Close() + + tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2) + tt = typ.Forward(ctx, tt) + + ctx.Forward(tt).Compute(tt) + if diff := cmp.Diff(want, tt.Floats()); diff != "" { + t.Error(diff) + } + }) + } +} diff --git a/model/model.go b/model/model.go index efef71d8b..5493a4e63 100644 --- a/model/model.go +++ b/model/model.go @@ -5,7 +5,6 @@ import ( "fmt" _ "image/jpeg" _ "image/png" - "math" "os" "reflect" "strconv" @@ -21,6 +20,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model/input" ) @@ -108,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { } arch := b.Config().Architecture() - if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 { + if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone { arch = arch + "_embed" } diff --git a/model/models/bert/model.go b/model/models/bert/embed.go similarity index 98% rename from model/models/bert/model.go rename to model/models/bert/embed.go index fd1dbd773..166c11e13 100644 --- a/model/models/bert/model.go +++ b/model/models/bert/embed.go @@ -37,7 +37,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options) } - hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType) + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) if m.normalize { hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) } diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go index 395a0344d..525547767 100644 --- a/model/models/gemma3/embed.go +++ b/model/models/gemma3/embed.go @@ -22,7 +22,7 @@ type embedModel struct { func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) - hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType) + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) for _, dense := range m.Dense { hiddenStates = dense.Forward(ctx, hiddenStates) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 3a32384f8..480cfc19b 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -11,7 +11,6 @@ import ( "image" "log" "log/slog" - "math" "net" "net/http" "os" 
@@ -32,6 +31,7 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" "github.com/ollama/ollama/runner/common" @@ -405,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) { func (s *Server) run(ctx context.Context) { s.ready.Wait() - supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 + supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone var activeBatch batchState for { @@ -900,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { - if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 { + if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone { http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) return } From ad95d5b30be7b1e8d845697722a05f9e5d52b8c7 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 16 Sep 2025 09:51:19 -0700 Subject: [PATCH 24/32] use split activations when possible (#12293) * use ggml_*_split activations when possible * forward qkv --- ml/backend.go | 11 +++++----- ml/backend/ggml/ggml.go | 31 ++++++++++++++++++--------- ml/nn/attention.go | 2 ++ model/models/gemma2/model.go | 2 +- model/models/gemma3/model_text.go | 2 +- model/models/gemma3n/model_text.go | 5 ++--- model/models/gptoss/model.go | 2 +- model/models/llama/model.go | 2 +- model/models/llama4/model_text.go | 16 +++++++------- model/models/mistral3/model_text.go | 2 +- model/models/mistral3/model_vision.go | 2 +- model/models/mllama/model_text.go | 2 +- model/models/qwen2/model.go | 2 +- model/models/qwen25vl/model_text.go | 2 +- model/models/qwen25vl/model_vision.go | 3 +-- model/models/qwen3/model.go | 23 +++++++++----------- 16 files changed, 59 insertions(+), 50 deletions(-) diff --git a/ml/backend.go b/ml/backend.go index ef7564782..455715b0d 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -430,12 +430,13 @@ type Tensor interface { Sin(ctx Context) Tensor Cos(ctx Context) Tensor Tanh(ctx Context) Tensor - GELU(ctx Context) Tensor - QuickGELU(ctx Context) Tensor - SILU(ctx Context) Tensor - RELU(ctx Context) Tensor + GELU(ctx Context, up ...Tensor) Tensor + SILU(ctx Context, up ...Tensor) Tensor + RELU(ctx Context, up ...Tensor) Tensor Sigmoid(ctx Context) Tensor - SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor + + // AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit] + SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor Reshape(ctx Context, shape ...int) Tensor View(ctx Context, offset int, shape ...int) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index d5e2e9c9c..49dc3e1ab 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1431,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int } } -func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { +func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { + if len(t2) > 0 { + return &Tensor{ + b: t.b, + t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t), + } + } return &Tensor{ b: t.b, t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t), } } -func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor { - return &Tensor{ - 
b: t.b, - t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t), +func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { + if len(t2) > 0 { + return &Tensor{ + b: t.b, + t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t), + } } -} - -func (t *Tensor) SILU(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t), } } -func (t *Tensor) RELU(ctx ml.Context) ml.Tensor { +func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { + if len(t2) > 0 { + return &Tensor{ + b: t.b, + t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t), + } + } return &Tensor{ b: t.b, t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t), } } -func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor { +func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor { return &Tensor{ b: t.b, t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)), diff --git a/ml/nn/attention.go b/ml/nn/attention.go index 21b4a28ae..94dbde0b0 100644 --- a/ml/nn/attention.go +++ b/ml/nn/attention.go @@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache } func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { + ctx.Forward(query) if key != nil && value != nil { if query.Dim(0) != key.Dim(0) { panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) @@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) } + ctx.Forward(key, value) if cache != nil { cache.Put(ctx, key, value) } diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 8ccb9f92a..96ace7c74 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -138,7 +138,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 5e515a927..d38746dc8 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -123,7 +123,7 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index eeb9ab028..2682a45f7 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position } active = d.PerLayerInputGate.Forward(ctx, active) - active = active.GELU(ctx) - active = active.Mul(ctx, perLayerInput) + active = active.GELU(ctx, perLayerInput) active = d.PerLayerProjection.Forward(ctx, active) active = 
d.PostPerLayerNorm.Forward(ctx, active, opts.eps) @@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx) } - hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates) + hiddenStates = hiddenStates.GELU(ctx, upStates) hiddenStates = mlp.Down.Forward(ctx, hiddenStates) return hiddenStates } diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index a74f76487..8456ea5f7 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts * up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts) } - hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7) + hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7) experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts = experts.Mul(ctx, routingWeights) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 51273a014..572c687a9 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -118,7 +118,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index 045ab403f..dbe6bba7a 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -58,14 +58,14 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } type TextExperts struct { - Gate *nn.Linear `gguf:"ffn_gate_exps"` - Up *nn.Linear `gguf:"ffn_up_exps"` - Down *nn.Linear `gguf:"ffn_down_exps"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` } func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { @@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) hiddenStates = hiddenStates.Mul(ctx, scores) - upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts) - gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts) - downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) + upStates := e.Up.Forward(ctx, hiddenStates, experts) + gateStates := e.Gate.Forward(ctx, hiddenStates, experts) + downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) for i := 1; i < opts.numExpertsUsed; i++ { @@ -96,7 +96,7 @@ type TextSharedExpert struct { } func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, 
hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 19c36f9fe..132d1756c 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -65,7 +65,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 65bdcff2a..3bfb8c90a 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -51,7 +51,7 @@ type VisionMLP struct { } func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 47a518ced..cb18f0878 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -58,7 +58,7 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 93a502612..5a8bea29e 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -59,7 +59,7 @@ type MLP struct { } func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index 4b6bc1666..4f4e1effd 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -90,7 +90,7 @@ type MLP struct { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { // Apply SwiGLU activation gating - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) // Project back to hidden dimension return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 4d7afaa14..3dd60e3ba 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -100,8 +100,7 @@ type VisionMLP struct { func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { // Using activation as specified in config (likely GELU or SiLU/Swish) gateOutput := mlp.Gate.Forward(ctx, hiddenStates) 
- upOutput := mlp.Up.Forward(ctx, hiddenStates) - hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput) + hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index ec2adaa7d..3f86d0236 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -30,10 +30,10 @@ func (o Options) headDim() int { } type Attention struct { - QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` Query *nn.Linear `gguf:"attn_q"` - KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` Value *nn.Linear `gguf:"attn_v"` Output *nn.Linear `gguf:"attn_output"` } @@ -65,10 +65,10 @@ type MLP interface { } type sparse struct { - Router *nn.Linear `gguf:"ffn_gate_inp"` - Gate *nn.Linear `gguf:"ffn_gate_exps"` - Up *nn.Linear `gguf:"ffn_up_exps"` - Down *nn.Linear `gguf:"ffn_down_exps"` + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` } func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { @@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) - upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts)) - hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts) - hiddenStates = hiddenStates.SILU(ctx) - hiddenStates = hiddenStates.Mul(ctx, upStates) - - experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts) + experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts = experts.Mul(ctx, routingWeights) nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) @@ -111,7 +107,8 @@ type dense struct { } func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates). + SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } From b225508c9b8f9118b57798c76f51f0e52835fabc Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 16 Sep 2025 16:18:07 -0700 Subject: [PATCH 25/32] logutil: fix source field (#12279) --- logutil/logutil.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/logutil/logutil.go b/logutil/logutil.go index fff277b84..00daf6a6e 100644 --- a/logutil/logutil.go +++ b/logutil/logutil.go @@ -5,6 +5,8 @@ import ( "io" "log/slog" "path/filepath" + "runtime" + "time" ) const LevelTrace slog.Level = -8 @@ -29,10 +31,18 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger { })) } +type key string + func Trace(msg string, args ...any) { - slog.Log(context.TODO(), LevelTrace, msg, args...) + TraceContext(context.WithValue(context.TODO(), key("skip"), 1), msg, args...) } func TraceContext(ctx context.Context, msg string, args ...any) { - slog.Log(ctx, LevelTrace, msg, args...) 
+ if logger := slog.Default(); logger.Enabled(ctx, LevelTrace) { + skip, _ := ctx.Value(key("skip")).(int) + pc, _, _, _ := runtime.Caller(1 + skip) + record := slog.NewRecord(time.Now(), LevelTrace, msg, pc) + record.Add(args...) + logger.Handler().Handle(ctx, record) + } } From 05d53457af8fda79c0e3884f316144d6c2aed5b9 Mon Sep 17 00:00:00 2001 From: russcoss Date: Tue, 16 Sep 2025 20:14:21 -0400 Subject: [PATCH 26/32] refactor: use the built-in max/min to simplify the code (#12280) Signed-off-by: russcoss --- runner/llamarunner/cache.go | 7 +------ runner/ollamarunner/cache.go | 7 +------ server/internal/internal/backoff/backoff.go | 5 +---- server/sched.go | 5 +---- 4 files changed, 4 insertions(+), 20 deletions(-) diff --git a/runner/llamarunner/cache.go b/runner/llamarunner/cache.go index 44b246134..9ed1c2924 100644 --- a/runner/llamarunner/cache.go +++ b/runner/llamarunner/cache.go @@ -204,13 +204,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int { targetFree = max(targetFree, 1) currentFree := c.numCtx - inputLen - discard := targetFree - currentFree - if discard < 0 { - discard = 0 - } - - return discard + return max(targetFree-currentFree, 0) } type ErrReprocessInputs struct { diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index f558f7b87..a3ffc3bd2 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -242,13 +242,8 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { targetFree = max(targetFree, 1) currentFree := c.numCtx - inputLen - discard := targetFree - currentFree - if discard < 0 { - discard = 0 - } - - return discard + return max(targetFree-currentFree, 0) } type ErrReprocessInputs struct { diff --git a/server/internal/internal/backoff/backoff.go b/server/internal/internal/backoff/backoff.go index 1f0634f7c..08b4ed7f9 100644 --- a/server/internal/internal/backoff/backoff.go +++ b/server/internal/internal/backoff/backoff.go @@ -25,10 +25,7 @@ func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] { // n^2 backoff timer is a little smoother than the // common choice of 2^n. - d := time.Duration(n*n) * 10 * time.Millisecond - if d > maxBackoff { - d = maxBackoff - } + d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff) // Randomize the delay between 0.5-1.5 x msec, in order // to prevent accidental "thundering herd" problems. d = time.Duration(float64(d) * (rand.Float64() + 0.5)) diff --git a/server/sched.go b/server/sched.go index c501c0e85..74aa406af 100644 --- a/server/sched.go +++ b/server/sched.go @@ -382,10 +382,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm // load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs // (if any). Returns whether the scheduler needs to evict a model to make this one fit. 
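One note on the min/max refactor in this patch: the simplified call sites rely on the generic min and max built-ins added in Go 1.21, so the clamp branches disappear without pulling in any helper or extra import. A minimal standalone sketch of the idiom, with made-up values (illustrative only, not part of the patch):

package main

import "fmt"

func main() {
	targetFree, currentFree := 6, 10
	// max clamps the discard count at zero, replacing the old "if discard < 0" branch.
	discard := max(targetFree-currentFree, 0)
	// min caps the n^2 backoff at its ceiling, replacing the old "if d > maxBackoff" branch.
	d := min(3*3*10, 250) // milliseconds
	fmt.Println(discard, d)
}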
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool { - numParallel := int(envconfig.NumParallel()) - if numParallel < 1 { - numParallel = 1 - } + numParallel := max(int(envconfig.NumParallel()), 1) // Embedding models should always be loaded with parallel=1 if req.model.CheckCapabilities(model.CapabilityCompletion) != nil { From a417ac97ee685e6197dc35cec24bba6cffee1f94 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 17 Sep 2025 09:48:21 -0700 Subject: [PATCH 27/32] prefer ollama engine for qwen3 (#12310) --- fs/ggml/ggml.go | 1 + 1 file changed, 1 insertion(+) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 6b582b499..5da902bcb 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -243,6 +243,7 @@ func (kv KV) OllamaEngineRequired() bool { "gemma3", "gemma3n", "mistral3", + "qwen3", "llama4", "mllama", "qwen25vl", From 564b558c92973ae9eda4ad585359e7f39b2dbff2 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 17 Sep 2025 12:12:21 -0700 Subject: [PATCH 28/32] fix(llama): other llama flavours (#12308) * fix(llama): rope scale * spm llama * skip moe models * cleanup --- model/models/gemma2/model.go | 6 +-- model/models/gemma3/model_text.go | 6 +-- model/models/gemma3n/model_text.go | 8 ++-- model/models/llama/model.go | 74 ++++++++++++++++------------- model/models/llama4/model_text.go | 8 ++-- model/models/mistral3/model_text.go | 8 ++-- model/models/mllama/model_text.go | 8 ++-- model/models/qwen2/model.go | 8 ++-- model/models/qwen25vl/model_text.go | 8 ++-- model/models/qwen3/model.go | 8 ++-- 10 files changed, 75 insertions(+), 67 deletions(-) diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 96ace7c74..81d41f2ab 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) { attnValLen: int(c.Uint("attention.value_length")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base", 10000.0), - ropeScale: c.Float("rope.freq_scale", 1.0), + ropeScale: c.Float("rope.scaling.factor", 1.0), attnLogitSoftcap: c.Float("attn_logit_softcapping"), finalLogitSoftcap: c.Float("final_logit_softcapping"), }, @@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index d38746dc8..c2a526080 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -53,7 +53,7 @@ func newTextModel(c fs.Config) *TextModel { eps: 
c.Float("attention.layer_norm_rms_epsilon", 1e-06), ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), - ropeScale: c.Float("rope.freq_scale", 1.0), + ropeScale: c.Float("rope.scaling.factor", 1.0), }, } @@ -84,7 +84,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = sa.QueryNorm.Forward(ctx, q, opts.eps) - q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -95,7 +95,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = sa.KeyNorm.Forward(ctx, k, opts.eps) - k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index 2682a45f7..d0e9a026a 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.ropeBaseLocal } - return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil } type TextScaledWordEmbedding struct { @@ -256,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten query := attn.Query.Forward(ctx, hiddenStates) query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) query = attn.QueryNorm.Forward(ctx, query, opts.eps) - query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) var key, value ml.Tensor if !sharedKV { key = attn.Key.Forward(ctx, hiddenStates) key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) key = attn.KeyNorm.Forward(ctx, key, opts.eps) - key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) value = attn.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) @@ -349,7 +349,7 @@ func newTextModel(c fs.Config) *TextModel { eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), ropeBase: c.Float("rope.freq_base", 1_000_000), ropeBaseLocal: c.Float("rope.freq_base_local", 10_000), - ropeScale: c.Float("rope.freq_scale", 1.0), + ropeScale: c.Float("rope.scaling.factor", 1.0), slidingWindowPattern: c.Bools("attention.sliding_window_pattern"), activationSparsityScale: c.Floats("activation_sparsity_scale"), diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 572c687a9..f6ec02273 100644 --- a/model/models/llama/model.go +++ 
b/model/models/llama/model.go @@ -2,7 +2,6 @@ package llama import ( "cmp" - "fmt" "math" "github.com/ollama/ollama/fs" @@ -23,51 +22,60 @@ type Options struct { type Model struct { model.Base - model.BytePairEncoding + model.TextProcessor TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` OutputNorm *nn.RMSNorm `gguf:"output_norm"` Output *nn.Linear `gguf:"output,alt:token_embd"` - *Options + Options } func New(c fs.Config) (model.Model, error) { - // This model currently only supports the gpt2 tokenizer - if c.String("tokenizer.ggml.model") == "llama" { - return nil, fmt.Errorf("unsupported tokenizer: llama") + if c.Uint("expert_count") > 0 { + // TODO: support mixtures of experts + return nil, model.ErrUnsupportedModel } - // Best effort detection of library/deepseek-coder model(s) which are incompatible - if c.String("general.name") == "deepseek-ai" { - return nil, fmt.Errorf("unsupported model: %s", c.String("general.name")) - } - m := Model{ - BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), - &model.Vocabulary{ - Values: c.Strings("tokenizer.ggml.tokens"), - Types: c.Ints("tokenizer.ggml.token_type"), - Merges: c.Strings("tokenizer.ggml.merges"), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, - AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOS: append( - []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, - c.Ints("tokenizer.ggml.eos_token_ids")..., - ), - }, + + var processor model.TextProcessor + vocabulary := model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., ), - Layers: make([]Layer, c.Uint("block_count")), - Options: &Options{ + } + switch c.String("tokenizer.ggml.model") { + case "gpt2": + processor = model.NewBytePairEncoding( + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + &vocabulary, + ) + case "llama": + processor = model.NewSentencePiece(&vocabulary) + default: + return nil, model.ErrUnsupportedTokenizer + } + + m := Model{ + TextProcessor: processor, + Layers: make([]Layer, c.Uint("block_count")), + Options: Options{ hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), headDim: int(c.Uint("attention.key_length")), ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeBase: c.Float("rope.freq_base", 1e5), + ropeScale: c.Float("rope.scaling.factor", 1), }, } @@ -98,8 +106,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - key 
= fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) @@ -108,7 +116,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) - return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil } type MLP struct { @@ -163,7 +171,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { outputs = batch.Outputs } - hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) + hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index dbe6bba7a..e0f932600 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) if useRope { - query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) } if opts.useQKNorm { @@ -196,7 +196,7 @@ func newTextModel(c fs.Config) *TextModel { numExpertsUsed: int(c.Uint("expert_used_count")), ropeDim: int(c.Uint("rope.dimension_count")), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), eps: c.Float("attention.layer_norm_rms_epsilon"), interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)), noRopeInterval: int(c.Uint("no_rope_interval", 4)), @@ -248,5 +248,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil } diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 132d1756c..d2e2eac6c 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, 
hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil } type MLP struct { @@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel { ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), }, } } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index cb18f0878..65f0a8278 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { // This will only get called for layers in the cache, which are just the self attention layers if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil } return key, nil @@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel { ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), crossAttentionLayers: c.Ints("attention.cross_attention_layers"), }, } diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 5a8bea29e..5a3458378 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -43,8 +43,8 @@ 
func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, value := attn.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) @@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) - return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil } func New(c fs.Config) (model.Model, error) { @@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) { headDim: int(c.Uint("attention.key_length")), ropeDim: int(c.Uint("rope.dimension_count")), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), eps: c.Float("attention.layer_norm_rms_epsilon"), }, } diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index 4f4e1effd..e6c6e6c19 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel { originalContextLength: int(c.Uint("context_length", 128000)), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), }, } @@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten // Shift applies rotary position embeddings to the key tensor for causal attention caching func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, 
rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil } // MLP implements the feed-forward network component with SwiGLU activation diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 3f86d0236..c4e0b2d80 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -52,8 +52,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, query = sa.QueryNorm.Forward(ctx, query, opts.eps) key = sa.KeyNorm.Forward(ctx, key, opts.eps) - query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) @@ -173,7 +173,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil } var _ model.Model = (*Model)(nil) @@ -213,7 +213,7 @@ func New(c fs.Config) (model.Model, error) { valueLength: int(c.Uint("attention.value_length")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), numExperts: int(c.Uint("expert_count")), numExpertsUsed: int(c.Uint("expert_used_count")), normTopKProb: c.Bool("norm_top_k_prob", true), From 9c5bf342bc34f94de9aa4a171d726e6b341a91e6 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 17 Sep 2025 13:05:09 -0700 Subject: [PATCH 29/32] fix: multi-cuda version skew (#12318) Ensure that in a version skewed multi-cuda setup we use the lowest version for all GPUs --- discover/cuda_common.go | 19 +++++++++---------- discover/gpu.go | 28 ++++++++++++++++++---------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/discover/cuda_common.go b/discover/cuda_common.go index 3c7a92114..a2c43420e 100644 --- a/discover/cuda_common.go +++ b/discover/cuda_common.go @@ -16,7 +16,7 @@ import ( // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. 
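A note on the rope scale handling in the llama patch above: the models now read "rope.scaling.factor" and hand its reciprocal to fast.RoPE, which keeps the effective frequency scale the same as the old "rope.freq_scale" path under the usual GGUF convention that freq_scale = 1/scaling_factor; that convention is assumed here rather than spelled out by the diff. A toy illustration with a hypothetical factor:

package main

import "fmt"

func main() {
	// Hypothetical metadata: rope.scaling.factor = 8 corresponds to the old
	// rope.freq_scale = 0.125 under the assumed reciprocal convention.
	scalingFactor := 8.0
	freqScale := 1.0 / scalingFactor
	fmt.Println(freqScale) // prints 0.125
}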
var CudaTegra string = os.Getenv("JETSON_JETPACK") -func cudaVariant(gpuInfo CudaGPUInfo) string { +func cudaVariant(gpuInfos []CudaGPUInfo) string { if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" { if CudaTegra != "" { ver := strings.Split(CudaTegra, ".") @@ -45,20 +45,19 @@ func cudaVariant(gpuInfo CudaGPUInfo) string { } } - // Check GPU compute capability FIRST - isOldGPU := gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) - if isOldGPU { - // GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1) - return "v12" + // Check GPU compute capability FIRST, lowest common denominator if multi-gpu + for _, gpuInfo := range gpuInfos { + if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) { + // GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1) + return "v12" + } } // GPU is Turing or newer (CC >= 7.5) - can use newer CUDA - if gpuInfo.DriverMajor < 13 { + if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 { // The detected driver is older than 580 (Aug 2025) // Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance - if !isOldGPU { - slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor)) - } + slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor)) return "v12" } return "v13" diff --git a/discover/gpu.go b/discover/gpu.go index b09626118..a39bc7c3d 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -284,18 +284,8 @@ func GetGPUInfo() GpuInfoList { gpuInfo.MinimumMemory = cudaMinimumMemory gpuInfo.DriverMajor = driverMajor gpuInfo.DriverMinor = driverMinor - variant := cudaVariant(gpuInfo) - // Start with our bundled libraries - if variant != "" { - variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant) - if _, err := os.Stat(variantPath); err == nil { - // Put the variant directory first in the search path to avoid runtime linking to the wrong library - gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...) - } - } gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.Variant = variant if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) { unsupportedGPUs = append(unsupportedGPUs, @@ -333,6 +323,24 @@ func GetGPUInfo() GpuInfoList { // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... cudaGPUs = append(cudaGPUs, gpuInfo) } + // Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths + variant := cudaVariant(cudaGPUs) + var variantPath string + // Start with our bundled libraries + if variant != "" { + variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant) + if _, err := os.Stat(variantPath); err != nil { + variantPath = "" + } + } + + for i := range cudaGPUs { + cudaGPUs[i].Variant = variant + if variantPath != "" { + // Put the variant directory first in the search path to avoid runtime linking to the wrong library + cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...) 
+ } + } } // Intel From 8b894933a73f4c477ba1401299c29f3553b622ee Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Wed, 17 Sep 2025 14:40:53 -0700 Subject: [PATCH 30/32] engine: add remote proxy (#12307) --- api/client.go | 25 +++- api/types.go | 125 +++++++++++++--- auth/auth.go | 13 ++ cmd/cmd.go | 155 ++++++++++++++++++-- cmd/cmd_test.go | 8 +- envconfig/config.go | 12 ++ server/create.go | 152 ++++++++++++++++--- server/create_test.go | 151 +++++++++++++++++++ server/images.go | 46 ++++-- server/routes.go | 277 +++++++++++++++++++++++++++++++---- server/routes_create_test.go | 74 ++++++++++ server/routes_test.go | 10 +- 12 files changed, 948 insertions(+), 100 deletions(-) diff --git a/api/client.go b/api/client.go index 7cc2acb3d..20e6d7957 100644 --- a/api/client.go +++ b/api/client.go @@ -222,7 +222,17 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f return fmt.Errorf("unmarshal: %w", err) } - if response.StatusCode >= http.StatusBadRequest { + if response.StatusCode == http.StatusUnauthorized { + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + return AuthorizationError{ + StatusCode: response.StatusCode, + Status: response.Status, + PublicKey: pubKey, + } + } else if response.StatusCode >= http.StatusBadRequest { return StatusError{ StatusCode: response.StatusCode, Status: response.Status, @@ -428,3 +438,16 @@ func (c *Client) Version(ctx context.Context) (string, error) { return version.Version, nil } + +// Signout will disconnect an ollama instance from ollama.com +func (c *Client) Signout(ctx context.Context, encodedKey string) error { + return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil) +} + +func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) { + var resp UserResponse + if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/api/types.go b/api/types.go index df3504c3b..5b8e034c2 100644 --- a/api/types.go +++ b/api/types.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "github.com/google/uuid" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) @@ -36,6 +38,19 @@ func (e StatusError) Error() string { } } +type AuthorizationError struct { + StatusCode int + Status string + PublicKey string `json:"public_key"` +} + +func (e AuthorizationError) Error() string { + if e.Status != "" { + return e.Status + } + return "something went wrong, please see the ollama server logs for details" +} + // ImageData represents the raw binary data of an image file. type ImageData []byte @@ -313,14 +328,29 @@ func (t *ToolFunction) String() string { // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. type ChatResponse struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - Message Message `json:"message"` - DoneReason string `json:"done_reason,omitempty"` - DebugInfo *DebugInfo `json:"_debug_info,omitempty"` + // Model is the model name that generated the response. + Model string `json:"model"` + // RemoteModel is the name of the upstream model that generated the response. + RemoteModel string `json:"remote_model,omitempty"` + + // RemoteHost is the URL of the upstream Ollama host that generated the response. + RemoteHost string `json:"remote_host,omitempty"` + + // CreatedAt is the timestamp of the response. 
+ CreatedAt time.Time `json:"created_at"` + + // Message contains the message or part of a message from the model. + Message Message `json:"message"` + + // Done specifies if the response is complete. Done bool `json:"done"` + // DoneReason is the reason the model stopped generating text. + DoneReason string `json:"done_reason,omitempty"` + + DebugInfo *DebugInfo `json:"_debug_info,omitempty"` + Metrics } @@ -425,20 +455,47 @@ type EmbeddingResponse struct { // CreateRequest is the request passed to [Client.Create]. type CreateRequest struct { - Model string `json:"model"` - Stream *bool `json:"stream,omitempty"` + // Model is the model name to create. + Model string `json:"model"` + + // Stream specifies whether the response is streaming; it is true by default. + Stream *bool `json:"stream,omitempty"` + + // Quantize is the quantization format for the model; leave blank to not change the quantization level. Quantize string `json:"quantize,omitempty"` - From string `json:"from,omitempty"` - Files map[string]string `json:"files,omitempty"` - Adapters map[string]string `json:"adapters,omitempty"` - Template string `json:"template,omitempty"` - License any `json:"license,omitempty"` - System string `json:"system,omitempty"` - Parameters map[string]any `json:"parameters,omitempty"` - Messages []Message `json:"messages,omitempty"` - Renderer string `json:"renderer,omitempty"` - Parser string `json:"parser,omitempty"` + // From is the name of the model or file to use as the source. + From string `json:"from,omitempty"` + + // RemoteHost is the URL of the upstream ollama API for the model (if any). + RemoteHost string `json:"remote_host,omitempty"` + + // Files is a map of files include when creating the model. + Files map[string]string `json:"files,omitempty"` + + // Adapters is a map of LoRA adapters to include when creating the model. + Adapters map[string]string `json:"adapters,omitempty"` + + // Template is the template used when constructing a request to the model. + Template string `json:"template,omitempty"` + + // License is a string or list of strings for licenses. + License any `json:"license,omitempty"` + + // System is the system prompt for the model. + System string `json:"system,omitempty"` + + // Parameters is a map of hyper-parameters which are applied to the model. + Parameters map[string]any `json:"parameters,omitempty"` + + // Messages is a list of messages added to the model before chat and generation requests. + Messages []Message `json:"messages,omitempty"` + + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` + + // Info is a map of additional information for the model + Info map[string]any `json:"info,omitempty"` // Deprecated: set the model name with Model instead Name string `json:"name"` @@ -480,6 +537,8 @@ type ShowResponse struct { Parser string `json:"parser,omitempty"` Details ModelDetails `json:"details,omitempty"` Messages []Message `json:"messages,omitempty"` + RemoteModel string `json:"remote_model,omitempty"` + RemoteHost string `json:"remote_host,omitempty"` ModelInfo map[string]any `json:"model_info,omitempty"` ProjectorInfo map[string]any `json:"projector_info,omitempty"` Tensors []Tensor `json:"tensors,omitempty"` @@ -538,12 +597,14 @@ type ProcessResponse struct { // ListModelResponse is a single model description in [ListResponse]. 
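To show how the new remote fields surface to API consumers, here is a hedged client-side sketch; it assumes a locally running server and uses only calls that already exist in the api package (ClientFromEnvironment, List), with the remote handling mirroring what the CLI's list output does further down in this patch:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	models, err := client.List(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	for _, m := range models.Models {
		if m.RemoteModel != "" {
			// Remote (proxied) models have no meaningful local size.
			fmt.Printf("%s -> %s on %s\n", m.Name, m.RemoteModel, m.RemoteHost)
			continue
		}
		fmt.Printf("%s (%d bytes)\n", m.Name, m.Size)
	}
}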
type ListModelResponse struct { - Name string `json:"name"` - Model string `json:"model"` - ModifiedAt time.Time `json:"modified_at"` - Size int64 `json:"size"` - Digest string `json:"digest"` - Details ModelDetails `json:"details,omitempty"` + Name string `json:"name"` + Model string `json:"model"` + RemoteModel string `json:"remote_model,omitempty"` + RemoteHost string `json:"remote_host,omitempty"` + ModifiedAt time.Time `json:"modified_at"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Details ModelDetails `json:"details,omitempty"` } // ProcessModelResponse is a single model description in [ProcessResponse]. @@ -567,6 +628,12 @@ type GenerateResponse struct { // Model is the model name that generated the response. Model string `json:"model"` + // RemoteModel is the name of the upstream model that generated the response. + RemoteModel string `json:"remote_model,omitempty"` + + // RemoteHost is the URL of the upstream Ollama host that generated the response. + RemoteHost string `json:"remote_host,omitempty"` + // CreatedAt is the timestamp of the response. CreatedAt time.Time `json:"created_at"` @@ -604,6 +671,18 @@ type ModelDetails struct { QuantizationLevel string `json:"quantization_level"` } +// UserResponse provides information about a user. +type UserResponse struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Bio string `json:"bio,omitempty"` + AvatarURL string `json:"avatarurl,omitempty"` + FirstName string `json:"firstname,omitempty"` + LastName string `json:"lastname,omitempty"` + Plan string `json:"plan,omitempty"` +} + // Tensor describes the metadata for a given tensor. type Tensor struct { Name string `json:"name"` diff --git a/auth/auth.go b/auth/auth.go index e1d854124..61a8626c3 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -19,6 +19,19 @@ import ( const defaultPrivateKey = "id_ed25519" func keyPath() (string, error) { + fileExists := func(fp string) bool { + info, err := os.Stat(fp) + if err != nil { + return false + } + return !info.IsDir() + } + + systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey) + if fileExists(systemPath) { + return systemPath, nil + } + home, err := os.UserHomeDir() if err != nil { return "", err diff --git a/cmd/cmd.go b/cmd/cmd.go index 19f1e192f..294e1662f 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ed25519" "crypto/rand" + "encoding/base64" "encoding/json" "encoding/pem" "errors" @@ -14,6 +15,7 @@ import ( "math" "net" "net/http" + "net/url" "os" "os/signal" "path/filepath" @@ -35,6 +37,7 @@ import ( "golang.org/x/term" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/parser" @@ -47,6 +50,8 @@ import ( "github.com/ollama/ollama/version" ) +const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n" + // ensureThinkingSupport emits a warning if the model does not advertise thinking support func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { if name == "" { @@ -286,7 +291,17 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { Think: opts.Think, } - return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil }) + return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error { + if r.RemoteModel != "" && opts.ShowConnect { + p.StopAndClear() + if 
strings.HasPrefix(r.RemoteHost, "https://ollama.com") { + fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel) + } else { + fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost) + } + } + return nil + }) } func StopHandler(cmd *cobra.Command, args []string) error { @@ -307,9 +322,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { interactive := true opts := runOptions{ - Model: args[0], - WordWrap: os.Getenv("TERM") == "xterm-256color", - Options: map[string]any{}, + Model: args[0], + WordWrap: os.Getenv("TERM") == "xterm-256color", + Options: map[string]any{}, + ShowConnect: true, } format, err := cmd.Flags().GetString("format") @@ -367,6 +383,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { } prompts = append([]string{string(in)}, prompts...) + opts.ShowConnect = false opts.WordWrap = false interactive = false } @@ -433,6 +450,21 @@ func RunHandler(cmd *cobra.Command, args []string) error { if interactive { if err := loadOrUnloadModel(cmd, &opts); err != nil { + var sErr api.AuthorizationError + if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + // the server and the client both have the same public key + if pubKey == sErr.PublicKey { + h, _ := os.Hostname() + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") + fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey) + } + return nil + } return err } @@ -453,6 +485,56 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generate(cmd, opts) } +func SigninHandler(cmd *cobra.Command, args []string) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + user, err := client.Whoami(cmd.Context()) + if err != nil { + return err + } + + if user != nil && user.Name != "" { + fmt.Printf("You are already signed in as user '%s'\n", user.Name) + fmt.Println() + return nil + } + + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + + h, _ := os.Hostname() + fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey) + + return nil +} + +func SignoutHandler(cmd *cobra.Command, args []string) error { + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + err = client.Signout(cmd.Context(), encKey) + if err != nil { + return err + } + fmt.Println("You have signed out of ollama.com") + fmt.Println() + return nil +} + func PushHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { @@ -505,7 +587,8 @@ func PushHandler(cmd *cobra.Command, args []string) error { if spinner != nil { spinner.Stop() } - if strings.Contains(err.Error(), "access denied") { + errStr := strings.ToLower(err.Error()) + if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") { return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") } return err @@ -539,7 +622,14 @@ func ListHandler(cmd *cobra.Command, args []string) error { for _, m := range models.Models { if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) { - data = 
append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")}) + var size string + if m.RemoteModel != "" { + size = "-" + } else { + size = format.HumanBytes(m.Size) + } + + data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")}) } } @@ -624,8 +714,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { KeepAlive: &api.Duration{Duration: 0}, } if err := loadOrUnloadModel(cmd, opts); err != nil { - if !strings.Contains(err.Error(), "not found") { - return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err) + if !strings.Contains(strings.ToLower(err.Error()), "not found") { + fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0]) } } @@ -736,12 +826,36 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error { } tableRender("Model", func() (rows [][]string) { + if resp.RemoteHost != "" { + rows = append(rows, []string{"", "Remote model", resp.RemoteModel}) + rows = append(rows, []string{"", "Remote URL", resp.RemoteHost}) + } + if resp.ModelInfo != nil { arch := resp.ModelInfo["general.architecture"].(string) rows = append(rows, []string{"", "architecture", arch}) - rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))}) - rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)}) - rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)}) + + var paramStr string + if resp.Details.ParameterSize != "" { + paramStr = resp.Details.ParameterSize + } else if v, ok := resp.ModelInfo["general.parameter_count"]; ok { + if f, ok := v.(float64); ok { + paramStr = format.HumanNumber(uint64(f)) + } + } + rows = append(rows, []string{"", "parameters", paramStr}) + + if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok { + if f, ok := v.(float64); ok { + rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)}) + } + } + + if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok { + if f, ok := v.(float64); ok { + rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)}) + } + } } else { rows = append(rows, []string{"", "architecture", resp.Details.Family}) rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) @@ -989,6 +1103,7 @@ type runOptions struct { KeepAlive *api.Duration Think *api.ThinkValue HideThinking bool + ShowConnect bool } type displayResponseState struct { @@ -1544,6 +1659,22 @@ func NewCLI() *cobra.Command { pushCmd.Flags().Bool("insecure", false, "Use an insecure registry") + signinCmd := &cobra.Command{ + Use: "signin", + Short: "Sign in to ollama.com", + Args: cobra.ExactArgs(0), + PreRunE: checkServerHeartbeat, + RunE: SigninHandler, + } + + signoutCmd := &cobra.Command{ + Use: "signout", + Short: "Sign out from ollama.com", + Args: cobra.ExactArgs(0), + PreRunE: checkServerHeartbeat, + RunE: SignoutHandler, + } + listCmd := &cobra.Command{ Use: "list", Aliases: []string{"ls"}, @@ -1638,6 +1769,8 @@ func NewCLI() *cobra.Command { stopCmd, pullCmd, pushCmd, + signinCmd, + signoutCmd, listCmd, psCmd, copyCmd, diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index cf5fe7caa..bb793572f 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -3,6 +3,7 
@@ package cmd import ( "bytes" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -304,6 +305,8 @@ func TestDeleteHandler(t *testing.T) { w.WriteHeader(http.StatusOK) } else { w.WriteHeader(http.StatusNotFound) + errPayload := `{"error":"model '%s' not found"}` + w.Write([]byte(fmt.Sprintf(errPayload, req.Name))) } return } @@ -346,7 +349,7 @@ func TestDeleteHandler(t *testing.T) { } err := DeleteHandler(cmd, []string{"test-model-not-found"}) - if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") { + if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") { t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err) } } @@ -499,7 +502,7 @@ func TestPushHandler(t *testing.T) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) err := json.NewEncoder(w).Encode(map[string]string{ - "error": "access denied", + "error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}", }) if err != nil { t.Fatal(err) @@ -522,6 +525,7 @@ func TestPushHandler(t *testing.T) { defer mockServer.Close() t.Setenv("OLLAMA_HOST", mockServer.URL) + initializeKeypair() cmd := &cobra.Command{} cmd.Flags().Bool("insecure", false, "") diff --git a/envconfig/config.go b/envconfig/config.go index 7fc018870..09243ab95 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -134,6 +134,17 @@ func LoadTimeout() (loadTimeout time.Duration) { return loadTimeout } +func Remotes() []string { + var r []string + raw := strings.TrimSpace(Var("OLLAMA_REMOTES")) + if raw == "" { + r = []string{"ollama.com"} + } else { + r = strings.Split(raw, ",") + } + return r +} + func Bool(k string) func() bool { return func() bool { if s := Var(k); s != "" { @@ -270,6 +281,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, + "OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"}, // Informational "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, diff --git a/server/create.go b/server/create.go index f08f18b34..19f24ec80 100644 --- a/server/create.go +++ b/server/create.go @@ -10,8 +10,11 @@ import ( "io" "io/fs" "log/slog" + "net" "net/http" + "net/url" "os" + "path" "path/filepath" "slices" "strings" @@ -39,6 +42,14 @@ var ( ) func (s *Server) CreateHandler(c *gin.Context) { + config := &ConfigV2{ + OS: "linux", + Architecture: "amd64", + RootFS: RootFS{ + Type: "layers", + }, + } + var r api.CreateRequest if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) @@ -48,6 +59,9 @@ func (s *Server) CreateHandler(c *gin.Context) { return } + config.Renderer = r.Renderer + config.Parser = r.Parser + for v := range r.Files { if !fs.ValidPath(v) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()}) @@ -77,20 +91,34 @@ func (s *Server) CreateHandler(c *gin.Context) { oldManifest, _ := ParseNamedManifest(name) var baseLayers []*layerGGML + var err error + var remote bool + if r.From != "" { - slog.Debug("create model from 
model name") + slog.Debug("create model from model name", "from", r.From) fromName := model.ParseName(r.From) if !fromName.IsValid() { ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest} return } + if r.RemoteHost != "" { + ru, err := remoteURL(r.RemoteHost) + if err != nil { + ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest} + return + } - ctx, cancel := context.WithCancel(c.Request.Context()) - defer cancel() + config.RemoteModel = r.From + config.RemoteHost = ru + remote = true + } else { + ctx, cancel := context.WithCancel(c.Request.Context()) + defer cancel() - baseLayers, err = parseFromModel(ctx, fromName, fn) - if err != nil { - ch <- gin.H{"error": err.Error()} + baseLayers, err = parseFromModel(ctx, fromName, fn) + if err != nil { + ch <- gin.H{"error": err.Error()} + } } } else if r.Files != nil { baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn) @@ -110,7 +138,7 @@ func (s *Server) CreateHandler(c *gin.Context) { } var adapterLayers []*layerGGML - if r.Adapters != nil { + if !remote && r.Adapters != nil { adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn) if err != nil { for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} { @@ -128,7 +156,56 @@ func (s *Server) CreateHandler(c *gin.Context) { baseLayers = append(baseLayers, adapterLayers...) } - if err := createModel(r, name, baseLayers, fn); err != nil { + // Info is not currently exposed by Modelfiles, but allows overriding various + // config values + if r.Info != nil { + caps, ok := r.Info["capabilities"] + if ok { + switch tcaps := caps.(type) { + case []any: + caps := make([]string, len(tcaps)) + for i, c := range tcaps { + str, ok := c.(string) + if !ok { + continue + } + caps[i] = str + } + config.Capabilities = append(config.Capabilities, caps...) + } + } + + strFromInfo := func(k string) string { + v, ok := r.Info[k] + if ok { + val := v.(string) + return val + } + return "" + } + + vFromInfo := func(k string) float64 { + v, ok := r.Info[k] + if ok { + val := v.(float64) + return val + } + return 0 + } + + config.ModelFamily = strFromInfo("model_family") + if config.ModelFamily != "" { + config.ModelFamilies = []string{config.ModelFamily} + } + + config.BaseName = strFromInfo("base_name") + config.FileType = strFromInfo("quantization_level") + config.ModelType = strFromInfo("parameter_size") + config.ContextLen = int(vFromInfo("context_length")) + config.EmbedLen = int(vFromInfo("embedding_length")) + } + + if err := createModel(r, name, baseLayers, config, fn); err != nil { if errors.Is(err, errBadTemplate) { ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest} return @@ -154,6 +231,51 @@ func (s *Server) CreateHandler(c *gin.Context) { streamResponse(c, ch) } +func remoteURL(raw string) (string, error) { + // Special‑case: user supplied only a path ("/foo/bar"). 
+ if strings.HasPrefix(raw, "/") { + return (&url.URL{ + Scheme: "http", + Host: net.JoinHostPort("localhost", "11434"), + Path: path.Clean(raw), + }).String(), nil + } + + if !strings.Contains(raw, "://") { + raw = "http://" + raw + } + + if raw == "ollama.com" || raw == "http://ollama.com" { + raw = "https://ollama.com:443" + } + + u, err := url.Parse(raw) + if err != nil { + return "", fmt.Errorf("parse error: %w", err) + } + + if u.Host == "" { + u.Host = "localhost" + } + + hostPart, portPart, err := net.SplitHostPort(u.Host) + if err == nil { + u.Host = net.JoinHostPort(hostPart, portPart) + } else { + u.Host = net.JoinHostPort(u.Host, "11434") + } + + if u.Path != "" { + u.Path = path.Clean(u.Path) + } + + if u.Path == "/" { + u.Path = "" + } + + return u.String(), nil +} + func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) { switch detectModelTypeFromFiles(files) { case "safetensors": @@ -316,17 +438,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) { return ggml.KV{}, fmt.Errorf("no base model was found") } -func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) { - config := ConfigV2{ - OS: "linux", - Architecture: "amd64", - RootFS: RootFS{ - Type: "layers", - }, - Renderer: r.Renderer, - Parser: r.Parser, - } - +func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) { var layers []Layer for _, layer := range baseLayers { if layer.GGML != nil { @@ -406,7 +518,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, return err } - configLayer, err := createConfigLayer(layers, config) + configLayer, err := createConfigLayer(layers, *config) if err != nil { return err } diff --git a/server/create_test.go b/server/create_test.go index 59a07ff14..061efb81a 100644 --- a/server/create_test.go +++ b/server/create_test.go @@ -104,3 +104,154 @@ func TestConvertFromSafetensors(t *testing.T) { }) } } + +func TestRemoteURL(t *testing.T) { + tests := []struct { + name string + input string + expected string + hasError bool + }{ + { + name: "absolute path", + input: "/foo/bar", + expected: "http://localhost:11434/foo/bar", + hasError: false, + }, + { + name: "absolute path with cleanup", + input: "/foo/../bar", + expected: "http://localhost:11434/bar", + hasError: false, + }, + { + name: "root path", + input: "/", + expected: "http://localhost:11434/", + hasError: false, + }, + { + name: "host without scheme", + input: "example.com", + expected: "http://example.com:11434", + hasError: false, + }, + { + name: "host with port", + input: "example.com:8080", + expected: "http://example.com:8080", + hasError: false, + }, + { + name: "full URL", + input: "https://example.com:8080/path", + expected: "https://example.com:8080/path", + hasError: false, + }, + { + name: "full URL with path cleanup", + input: "https://example.com:8080/path/../other", + expected: "https://example.com:8080/other", + hasError: false, + }, + { + name: "ollama.com special case", + input: "ollama.com", + expected: "https://ollama.com:443", + hasError: false, + }, + { + name: "http ollama.com special case", + input: "http://ollama.com", + expected: "https://ollama.com:443", + hasError: false, + }, + { + name: "URL with only host", + input: "http://example.com", + expected: "http://example.com:11434", + hasError: false, + }, + 
{ + name: "URL with root path cleaned", + input: "http://example.com/", + expected: "http://example.com:11434", + hasError: false, + }, + { + name: "invalid URL", + input: "http://[::1]:namedport", // invalid port + expected: "", + hasError: true, + }, + { + name: "empty string", + input: "", + expected: "http://localhost:11434", + hasError: false, + }, + { + name: "host with scheme but no port", + input: "http://localhost", + expected: "http://localhost:11434", + hasError: false, + }, + { + name: "complex path cleanup", + input: "/a/b/../../c/./d", + expected: "http://localhost:11434/c/d", + hasError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := remoteURL(tt.input) + + if tt.hasError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestRemoteURL_Idempotent(t *testing.T) { + // Test that applying remoteURL twice gives the same result as applying it once + testInputs := []string{ + "/foo/bar", + "example.com", + "https://example.com:8080/path", + "ollama.com", + "http://localhost:11434", + } + + for _, input := range testInputs { + t.Run(input, func(t *testing.T) { + firstResult, err := remoteURL(input) + if err != nil { + t.Fatalf("first call failed: %v", err) + } + + secondResult, err := remoteURL(firstResult) + if err != nil { + t.Fatalf("second call failed: %v", err) + } + + if firstResult != secondResult { + t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult) + } + }) + } +} diff --git a/server/images.go b/server/images.go index 6432860f8..9466b7fb4 100644 --- a/server/images.go +++ b/server/images.go @@ -74,21 +74,29 @@ func (m *Model) Capabilities() []model.Capability { capabilities := []model.Capability{} // Check for completion capability - f, err := gguf.Open(m.ModelPath) - if err == nil { - defer f.Close() + if m.ModelPath != "" { + f, err := gguf.Open(m.ModelPath) + if err == nil { + defer f.Close() - if f.KeyValue("pooling_type").Valid() { - capabilities = append(capabilities, model.CapabilityEmbedding) + if f.KeyValue("pooling_type").Valid() { + capabilities = append(capabilities, model.CapabilityEmbedding) + } else { + // If no embedding is specified, we assume the model supports completion + capabilities = append(capabilities, model.CapabilityCompletion) + } + if f.KeyValue("vision.block_count").Valid() { + capabilities = append(capabilities, model.CapabilityVision) + } } else { - // If no embedding is specified, we assume the model supports completion - capabilities = append(capabilities, model.CapabilityCompletion) + slog.Error("couldn't open model file", "error", err) } - if f.KeyValue("vision.block_count").Valid() { - capabilities = append(capabilities, model.CapabilityVision) + } else if len(m.Config.Capabilities) > 0 { + for _, c := range m.Config.Capabilities { + capabilities = append(capabilities, model.Capability(c)) } } else { - slog.Error("couldn't open model file", "error", err) + slog.Warn("unknown capabilities for model", "model", m.Name) } if m.Template == nil { @@ -111,6 +119,11 @@ func (m *Model) Capabilities() []model.Capability { capabilities = append(capabilities, model.CapabilityVision) } + // Skip the thinking check if it's already set + if slices.Contains(capabilities, "thinking") { + return capabilities + } + // Check for thinking capability 
openingTag, closingTag := thinking.InferTags(m.Template.Template) hasTags := openingTag != "" && closingTag != "" @@ -253,11 +266,20 @@ type ConfigV2 struct { ModelFormat string `json:"model_format"` ModelFamily string `json:"model_family"` ModelFamilies []string `json:"model_families"` - ModelType string `json:"model_type"` - FileType string `json:"file_type"` + ModelType string `json:"model_type"` // shown as Parameter Size + FileType string `json:"file_type"` // shown as Quantization Level Renderer string `json:"renderer,omitempty"` Parser string `json:"parser,omitempty"` + RemoteHost string `json:"remote_host,omitempty"` + RemoteModel string `json:"remote_model,omitempty"` + + // used for remotes + Capabilities []string `json:"capabilities,omitempty"` + ContextLen int `json:"context_length,omitempty"` + EmbedLen int `json:"embedding_length,omitempty"` + BaseName string `json:"base_name,omitempty"` + // required by spec Architecture string `json:"architecture"` OS string `json:"os"` diff --git a/server/routes.go b/server/routes.go index e999c6c01..dc868038c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -15,6 +15,7 @@ import ( "net" "net/http" "net/netip" + "net/url" "os" "os/signal" "slices" @@ -28,6 +29,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/discover" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" @@ -189,6 +191,84 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + origModel := req.Model + + remoteURL, err := url.Parse(m.Config.RemoteHost) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) { + slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname()) + c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"}) + return + } + + req.Model = m.Config.RemoteModel + + if req.Template == "" && m.Template.String() != "" { + req.Template = m.Template.String() + } + + if req.Options == nil { + req.Options = map[string]any{} + } + + for k, v := range m.Options { + if _, ok := req.Options[k]; !ok { + req.Options[k] = v + } + } + + // update the system prompt from the model if one isn't already specified + if req.System == "" && m.System != "" { + req.System = m.System + } + + if len(m.Messages) > 0 { + slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead") + } + + fn := func(resp api.GenerateResponse) error { + resp.Model = origModel + resp.RemoteModel = m.Config.RemoteModel + resp.RemoteHost = m.Config.RemoteHost + + data, err := json.Marshal(resp) + if err != nil { + return err + } + + if _, err = c.Writer.Write(append(data, '\n')); err != nil { + return err + } + c.Writer.Flush() + return nil + } + + client := api.NewClient(remoteURL, http.DefaultClient) + err = client.Generate(c, &req, fn) + if err != nil { + var sErr api.AuthorizationError + if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { + pk, pkErr := auth.GetPublicKey() + if pkErr != nil { + slog.Error("couldn't get public key", "error", pkErr) + c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"}) + return + } + c.JSON(http.StatusUnauthorized, gin.H{"public_key": pk}) + return + } + c.JSON(http.StatusInternalServerError, 
gin.H{"error": err.Error()}) + return + } + + return + } + // expire the runner if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { s.sched.expireRunner(m) @@ -931,6 +1011,28 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { ModifiedAt: manifest.fi.ModTime(), } + if m.Config.RemoteHost != "" { + resp.RemoteHost = m.Config.RemoteHost + resp.RemoteModel = m.Config.RemoteModel + + if m.Config.ModelFamily != "" { + resp.ModelInfo = make(map[string]any) + resp.ModelInfo["general.architecture"] = m.Config.ModelFamily + + if m.Config.BaseName != "" { + resp.ModelInfo["general.basename"] = m.Config.BaseName + } + + if m.Config.ContextLen > 0 { + resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen + } + + if m.Config.EmbedLen > 0 { + resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen + } + } + } + var params []string cs := 30 for k, v := range m.Options { @@ -961,6 +1063,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { fmt.Fprint(&sb, m.String()) resp.Modelfile = sb.String() + // skip loading tensor information if this is a remote model + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + return resp, nil + } + kvData, tensors, err := getModelData(m.ModelPath, req.Verbose) if err != nil { return nil, err @@ -1037,11 +1144,13 @@ func (s *Server) ListHandler(c *gin.Context) { // tag should never be masked models = append(models, api.ListModelResponse{ - Model: n.DisplayShortest(), - Name: n.DisplayShortest(), - Size: m.Size(), - Digest: m.digest, - ModifiedAt: m.fi.ModTime(), + Model: n.DisplayShortest(), + Name: n.DisplayShortest(), + RemoteModel: cf.RemoteModel, + RemoteHost: cf.RemoteHost, + Size: m.Size(), + Digest: m.digest, + ModifiedAt: m.fi.ModTime(), Details: api.ModelDetails{ Format: cf.ModelFormat, Family: cf.ModelFamily, @@ -1301,6 +1410,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/show", s.ShowHandler) r.DELETE("/api/delete", s.DeleteHandler) + r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler) + r.POST("/api/me", s.WhoamiHandler) + // Create r.POST("/api/create", s.CreateHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler) @@ -1497,6 +1609,49 @@ func streamResponse(c *gin.Context, ch chan any) { }) } +func (s *Server) WhoamiHandler(c *gin.Context) { + // todo allow other hosts + u, err := url.Parse("https://ollama.com") + if err != nil { + slog.Error(err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"}) + return + } + + client := api.NewClient(u, http.DefaultClient) + user, err := client.Whoami(c) + if err != nil { + slog.Error(err.Error()) + } + c.JSON(http.StatusOK, user) +} + +func (s *Server) SignoutHandler(c *gin.Context) { + encodedKey := c.Param("encodedKey") + + // todo allow other hosts + u, err := url.Parse("https://ollama.com") + if err != nil { + slog.Error(err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"}) + return + } + + client := api.NewClient(u, http.DefaultClient) + err = client.Signout(c, encodedKey) + if err != nil { + slog.Error(err.Error()) + if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") { + c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"}) + return + } + + c.JSON(http.StatusOK, 
nil) +} + func (s *Server) PsHandler(c *gin.Context) { models := []api.ProcessModelResponse{} @@ -1553,21 +1708,34 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - // expire the runner - if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { - model, err := GetModel(req.Model) - if err != nil { - switch { - case os.IsNotExist(err): - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) - case err.Error() == errtypes.InvalidModelNameErrMsg: - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - default: - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - return + name := model.ParseName(req.Model) + if !name.IsValid() { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + } + + name, err := getExistingName(name) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + } + + m, err := GetModel(req.Model) + if err != nil { + switch { + case os.IsNotExist(err): + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + case err.Error() == errtypes.InvalidModelNameErrMsg: + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } - s.sched.expireRunner(model) + return + } + + // expire the runner + if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { + s.sched.expireRunner(m) c.JSON(http.StatusOK, api.ChatResponse{ Model: req.Model, @@ -1579,6 +1747,66 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + origModel := req.Model + + remoteURL, err := url.Parse(m.Config.RemoteHost) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) { + slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname()) + c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"}) + return + } + + req.Model = m.Config.RemoteModel + if req.Options == nil { + req.Options = map[string]any{} + } + + msgs := append(m.Messages, req.Messages...) + if req.Messages[0].Role != "system" && m.System != "" { + msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...) 
+ }
+ msgs = filterThinkTags(msgs, m)
+ req.Messages = msgs
+
+ for k, v := range m.Options {
+ if _, ok := req.Options[k]; !ok {
+ req.Options[k] = v
+ }
+ }
+
+ fn := func(resp api.ChatResponse) error {
+ resp.Model = origModel
+ resp.RemoteModel = m.Config.RemoteModel
+ resp.RemoteHost = m.Config.RemoteHost
+
+ data, err := json.Marshal(resp)
+ if err != nil {
+ return err
+ }
+
+ if _, err = c.Writer.Write(append(data, '\n')); err != nil {
+ return err
+ }
+ c.Writer.Flush()
+ return nil
+ }
+
+ client := api.NewClient(remoteURL, http.DefaultClient)
+ err = client.Chat(c, &req, fn)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ return
+ }
+
 caps := []model.Capability{model.CapabilityCompletion}
 if len(req.Tools) > 0 {
 caps = append(caps, model.CapabilityTools)
@@ -1587,17 +1815,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
 caps = append(caps, model.CapabilityThinking)
 }

- name := model.ParseName(req.Model)
- if !name.IsValid() {
- c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
- return
- }
- name, err := getExistingName(name)
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
- return
- }
-
 r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
 if errors.Is(err, errCapabilityCompletion) {
 c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
diff --git a/server/routes_create_test.go b/server/routes_create_test.go
index 3b3d99100..189ef0407 100644
--- a/server/routes_create_test.go
+++ b/server/routes_create_test.go
@@ -11,6 +11,7 @@ import (
 "net/http/httptest"
 "os"
 "path/filepath"
+ "reflect"
 "slices"
 "strings"
 "testing"
@@ -20,6 +21,7 @@ import (
 "github.com/ollama/ollama/api"
 "github.com/ollama/ollama/envconfig"
 "github.com/ollama/ollama/fs/ggml"
+ "github.com/ollama/ollama/types/model"
 )

 var stream bool = false
@@ -615,6 +617,78 @@ func TestCreateTemplateSystem(t *testing.T) {
 })
}

+func TestCreateAndShowRemoteModel(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var s Server
+
+ w := createRequest(t, s.CreateHandler, api.CreateRequest{
+ Model: "test",
+ From: "bob",
+ RemoteHost: "https://ollama.com",
+ Info: map[string]any{
+ "capabilities": []string{"completion", "tools", "thinking"},
+ "model_family": "gptoss",
+ "context_length": 131072,
+ "embedding_length": 2880,
+ "quantization_level": "MXFP4",
+ "parameter_size": "20.9B",
+ },
+ Stream: &stream,
+ })
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status code 200, actual %d", w.Code)
+ }
+
+ w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test"})
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status code 200, actual %d", w.Code)
+ }
+
+ var resp api.ShowResponse
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+ t.Fatal(err)
+ }
+
+ expectedDetails := api.ModelDetails{
+ ParentModel: "",
+ Format: "",
+ Family: "gptoss",
+ Families: []string{"gptoss"},
+ ParameterSize: "20.9B",
+ QuantizationLevel: "MXFP4",
+ }
+
+ if !reflect.DeepEqual(resp.Details, expectedDetails) {
+ t.Errorf("model details: expected %#v, actual %#v", expectedDetails, resp.Details)
+ }
+
+ expectedCaps := []model.Capability{
+ model.Capability("completion"),
+ model.Capability("tools"),
+ model.Capability("thinking"),
+ }
+
+ if !slices.Equal(resp.Capabilities, expectedCaps) {
+ t.Errorf("capabilities: expected %#v, actual %#v", expectedCaps, resp.Capabilities)
+ }
+
+ v, ok :=
resp.ModelInfo["gptoss.context_length"] + ctxlen := v.(float64) + if !ok || int(ctxlen) != 131072 { + t.Errorf("context len: expected %d, actual %d", 131072, int(ctxlen)) + } + + v, ok = resp.ModelInfo["gptoss.embedding_length"] + embedlen := v.(float64) + if !ok || int(embedlen) != 2880 { + t.Errorf("embed len: expected %d, actual %d", 2880, int(embedlen)) + } + + fmt.Printf("resp = %#v\n", resp) +} + func TestCreateLicenses(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/server/routes_test.go b/server/routes_test.go index 87b526633..bb7e2b7c1 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -126,7 +126,15 @@ func TestRoutes(t *testing.T) { t.Fatalf("failed to create model: %v", err) } - if err := createModel(r, modelName, baseLayers, fn); err != nil { + config := &ConfigV2{ + OS: "linux", + Architecture: "amd64", + RootFS: RootFS{ + Type: "layers", + }, + } + + if err := createModel(r, modelName, baseLayers, config, fn); err != nil { t.Fatal(err) } } From 9b8187b487159cda3e753f9d242303d857ba321c Mon Sep 17 00:00:00 2001 From: frob Date: Thu, 18 Sep 2025 01:39:04 +0200 Subject: [PATCH 31/32] server: skip parsing initial if provided in the prompt for /api/generate (#12289) --- server/routes.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/server/routes.go b/server/routes.go index dc868038c..b1def0dea 100644 --- a/server/routes.go +++ b/server/routes.go @@ -429,6 +429,9 @@ func (s *Server) GenerateHandler(c *gin.Context) { OpeningTag: openingTag, ClosingTag: closingTag, } + if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) { + thinkingState.AddContent(openingTag) + } } } From 2717dce6fe1bb4eab80abd5fbbd713211a7fc276 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Wed, 17 Sep 2025 17:43:17 -0700 Subject: [PATCH 32/32] convert: convert bf16 vision weights to fp16 (#12324) This change moves back to converting bf16 vision weights to fp16, specifically if they start with the name "v." (such as v.blk.0.attn_k.weight). This fixes a bug where converted images are failing because they are trying to call `im2col` which doesn't have a bf16 kernel in ggml. --- convert/reader_safetensors.go | 2 +- convert/reader_test.go | 62 +++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index 7f029f933..eea0de2f5 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -96,7 +96,7 @@ type safetensor struct { func (st safetensor) Kind() uint32 { kind := st.tensorBase.Kind() - if st.dtype == "BF16" && kind != tensorKindFP32 { + if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 { kind = tensorKindBF16 } diff --git a/convert/reader_test.go b/convert/reader_test.go index 6dbe32a51..c3d094f10 100644 --- a/convert/reader_test.go +++ b/convert/reader_test.go @@ -230,3 +230,65 @@ func TestSafetensors(t *testing.T) { }) } } + +func TestSafetensorKind(t *testing.T) { + tests := []struct { + name string + st safetensor + expected uint32 + }{ + { + name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16", + st: safetensor{ + tensorBase: &tensorBase{ + name: "weight.matrix", + shape: []uint64{10, 10}, // will default to FP16 + }, + dtype: "BF16", + }, + expected: tensorKindBF16, + }, + { + name: "BF16 dtype with v. 
prefix should return base kind", + st: safetensor{ + tensorBase: &tensorBase{ + name: "v.weight.matrix", + shape: []uint64{10, 10}, // will default to FP16 + }, + dtype: "BF16", + }, + expected: tensorKindFP16, + }, + { + name: "BF16 dtype with FP32 base kind should return FP32", + st: safetensor{ + tensorBase: &tensorBase{ + name: "weight.matrix", + shape: []uint64{10}, // will default to FP32 + }, + dtype: "BF16", + }, + expected: tensorKindFP32, + }, + { + name: "Non-BF16 dtype should return base kind", + st: safetensor{ + tensorBase: &tensorBase{ + name: "weight.matrix", + shape: []uint64{10, 10}, // will default to FP16 + }, + dtype: "FP16", + }, + expected: tensorKindFP16, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.st.Kind() + if result != tt.expected { + t.Errorf("Kind() = %d, expected %d", result, tt.expected) + } + }) + } +}
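
Editor's note, not part of the patches above: the BF16 rule in safetensor.Kind() can be read as a small standalone decision. The sketch below restates it in plain Go under stated assumptions; the constant values and the helper name pickKind are illustrative only, and the real converter derives the base kind from the tensor shape (FP32 for 1-D tensors, FP16 otherwise), as the tests note.

package main

import (
	"fmt"
	"strings"
)

// Illustrative stand-ins for the converter's tensor kind values.
const (
	kindFP32 uint32 = 0
	kindFP16 uint32 = 1
	kindBF16 uint32 = 30
)

// pickKind mirrors the rule added in convert/reader_safetensors.go: BF16 source
// tensors are normally emitted as BF16, but vision tensors (names starting with
// "v.") fall back to the base kind, since ggml's im2col has no BF16 kernel, and
// FP32 base kinds are never changed.
func pickKind(name, dtype string, baseKind uint32) uint32 {
	if !strings.HasPrefix(name, "v.") && dtype == "BF16" && baseKind != kindFP32 {
		return kindBF16
	}
	return baseKind
}

func main() {
	fmt.Println(pickKind("blk.0.attn_k.weight", "BF16", kindFP16))   // 30: text weight stays BF16
	fmt.Println(pickKind("v.blk.0.attn_k.weight", "BF16", kindFP16)) // 1: vision weight converted to FP16
	fmt.Println(pickKind("token_embd.weight", "BF16", kindFP32))     // 0: FP32 base kind is kept
}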