diff --git a/CMakeLists.txt b/CMakeLists.txt
index 92b1793b6..034fc7d79 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -106,9 +106,11 @@ if(CMAKE_HIP_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
if (WIN32)
- target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
+ target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
endif()
+ target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
+
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip
RUNTIME_DEPENDENCIES
diff --git a/Dockerfile b/Dockerfile
index 46d4713e7..4136fca71 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,7 +12,7 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
- && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
+ && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
@@ -86,10 +86,11 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS build
-ARG GOVERSION=1.23.4
-RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
-ENV PATH=/usr/local/go/bin:$PATH
WORKDIR /go/src/github.com/ollama/ollama
+COPY go.mod go.sum .
+RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
+ENV PATH=/usr/local/go/bin:$PATH
+RUN go mod download
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
diff --git a/README.md b/README.md
index 8d471f5ef..b4df5e2af 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-
+
@@ -54,6 +54,7 @@ Here are some example models that can be downloaded:
| Model | Parameters | Size | Download |
| ------------------ | ---------- | ----- | -------------------------------- |
+| QwQ | 32B | 20GB | `ollama run qwq` |
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
@@ -64,7 +65,7 @@ Here are some example models that can be downloaded:
| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` |
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
| Phi 4 | 14B | 9.1GB | `ollama run phi4` |
-| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
+| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` |
| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` |
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
@@ -75,7 +76,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
-| Solar | 10.7B | 6.1GB | `ollama run solar` |
+| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -275,6 +276,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Open WebUI](https://github.com/open-webui/open-webui)
+- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [Hollama](https://github.com/fmaclen/hollama)
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
@@ -387,6 +389,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
+- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugin integration)
+- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
### Cloud
@@ -430,6 +434,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Apple Vision Pro
+- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
- [Enchanted](https://github.com/AugustDev/enchanted)
### Database
@@ -507,10 +512,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Mobile
+- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
- [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
+- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux; start the Ollama service with one click on an Android device)
+- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
### Extensions & Plugins
@@ -556,12 +564,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
+- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
### Supported backends
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
### Observability
+- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration with Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
diff --git a/api/types.go b/api/types.go
index 14bb7eab4..bc5aad8f1 100644
--- a/api/types.go
+++ b/api/types.go
@@ -403,9 +403,9 @@ type CopyRequest struct {
// PullRequest is the request passed to [Client.Pull].
type PullRequest struct {
Model string `json:"model"`
- Insecure bool `json:"insecure,omitempty"`
- Username string `json:"username"`
- Password string `json:"password"`
+ Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored
+ Username string `json:"username"` // Deprecated: ignored
+ Password string `json:"password"` // Deprecated: ignored
Stream *bool `json:"stream,omitempty"`
// Deprecated: set the model name with Model instead
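
For reference, with the auth fields deprecated a pull request only needs `Model` (and optionally `Stream`). A minimal client sketch, assuming the standard `api.ClientFromEnvironment` constructor and an illustrative model name:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Deprecated Insecure/Username/Password fields are simply left unset.
	req := &api.PullRequest{Model: "llama3.2"}
	err = client.Pull(context.Background(), req, func(p api.ProgressResponse) error {
		fmt.Println(p.Status)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```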
diff --git a/cmd/cmd.go b/cmd/cmd.go
index 80ece4c60..c22a08f43 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
- "github.com/ollama/ollama/llama"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/runner"
@@ -256,6 +255,7 @@ func StopHandler(cmd *cobra.Command, args []string) error {
if strings.Contains(err.Error(), "not found") {
return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
}
+ return err
}
return nil
}
@@ -338,10 +338,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
- // TODO(jessegross): We should either find another way to know if this is
- // a vision model or remove the logic. Also consider that other modalities will
- // need different behavior anyways.
- opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
+ if len(info.ProjectorInfo) != 0 {
+ opts.MultiModal = true
+ }
+ for k := range info.ModelInfo {
+ if strings.Contains(k, ".vision.") {
+ opts.MultiModal = true
+ break
+ }
+ }
+
opts.ParentModel = info.Details.ParentModel
if interactive {
@@ -1274,7 +1280,6 @@ func NewCLI() *cobra.Command {
runnerCmd := &cobra.Command{
Use: "runner",
- Short: llama.PrintSystemInfo(),
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
return runner.Execute(os.Args[1:])
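
For reference, the new check treats a model as multimodal if it ships a projector or if any metadata key contains `.vision.` (for example `mllama.vision.image_size`). A standalone sketch of the same logic; the helper name is illustrative:

```go
package example

import "strings"

// isMultimodal mirrors the check above: a model is treated as multimodal if it
// ships a projector or if any metadata key contains ".vision."
// (for example "mllama.vision.image_size").
func isMultimodal(projectorInfo, modelInfo map[string]any) bool {
	if len(projectorInfo) != 0 {
		return true
	}
	for k := range modelInfo {
		if strings.Contains(k, ".vision.") {
			return true
		}
	}
	return false
}
```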
diff --git a/docs/development.md b/docs/development.md
index eb67dbfaf..cf6d91e27 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -118,6 +118,35 @@ To run tests, use `go test`:
go test ./...
```
+> NOTE: In rare circumstances, you may need to change a package that uses the
+> new "synctest" package in go1.24.
+>
+> If you do not have the "synctest" package enabled, you will not see any
+> resulting build or test failures locally, but CI will break.
+>
+> If you see failures in CI, you can either keep pushing changes to see if the
+> CI build passes, or you can enable the "synctest" package locally to see the
+> failures before pushing.
+>
+> To enable the "synctest" package for testing, run the following command:
+>
+> ```shell
+> GOEXPERIMENT=synctest go test ./...
+> ```
+>
+> If you wish to enable synctest for all go commands, you can set the
+> `GOEXPERIMENT` environment variable in your shell profile or by using:
+>
+> ```shell
+> go env -w GOEXPERIMENT=synctest
+> ```
+>
+> This enables the "synctest" package for all go commands without needing to
+> set it in every shell session.
+>
+> The synctest package is not required for production builds.
+
## Library detection
Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
diff --git a/docs/faq.md b/docs/faq.md
index 04e8433de..4aaccc2e4 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?
-By default, Ollama uses a context window size of 2048 tokens.
+By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
To change this when using `ollama run`, use `/set parameter`:
diff --git a/docs/linux.md b/docs/linux.md
index 12581bdd2..2dda87f3a 100644
--- a/docs/linux.md
+++ b/docs/linux.md
@@ -75,7 +75,7 @@ RestartSec=3
Environment="PATH=$PATH"
[Install]
-WantedBy=default.target
+WantedBy=multi-user.target
```
Then start the service:
diff --git a/docs/windows.md b/docs/windows.md
index 018cc41d0..78b99a5d7 100644
--- a/docs/windows.md
+++ b/docs/windows.md
@@ -81,9 +81,11 @@ help you keep up to date.
If you'd like to install or integrate Ollama as a service, a standalone
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
-and GPU library dependencies for Nvidia and AMD. This allows for embedding
-Ollama in existing applications, or running it as a system service via `ollama
-serve` with tools such as [NSSM](https://nssm.cc/).
+and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
+and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
+same directory. This allows for embedding Ollama in existing applications, or
+running it as a system service via `ollama serve` with tools such as
+[NSSM](https://nssm.cc/).
> [!NOTE]
> If you are upgrading from a prior version, you should remove the old directories first.
diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index b9f9cc178..8662c3b01 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -565,6 +565,43 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
return
}
+func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
+ switch llm.KV().Architecture() {
+ case "mllama":
+ for _, layer := range llm.Tensors().GroupLayers()["v"] {
+ weights += layer.Size()
+ }
+
+ kv := func(n string) uint64 {
+ if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
+ return uint64(v)
+ }
+
+ return 0
+ }
+
+ imageSize := kv("image_size")
+
+ maxNumTiles := kv("max_num_tiles")
+ embeddingLength := kv("embedding_length")
+ headCount := kv("attention.head_count")
+
+ numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
+ if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
+ numPatches++
+ }
+
+ numPaddedPatches := numPatches + 8 - (numPatches%8)%8
+
+ graphSize = 4 * (8 +
+ imageSize*imageSize*kv("num_channels")*maxNumTiles +
+ embeddingLength*numPatches*maxNumTiles +
+ 9*embeddingLength*numPaddedPatches*maxNumTiles +
+ numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
+ }
+ return weights, graphSize
+}
+
// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
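
For reference, a worked example of the patch arithmetic in `VisionGraphSize`, using illustrative mllama-style values (`image_size=560`, `patch_size=14`, class embedding present):

```go
package main

import "fmt"

func main() {
	// Illustrative mllama-style values.
	imageSize, patchSize := uint64(560), uint64(14)

	numPatches := (imageSize / patchSize) * (imageSize / patchSize) // 40*40 = 1600
	numPatches++                                                    // class embedding present -> 1601

	// same padding expression as the estimate above (rounds 1601 up to 1608)
	numPaddedPatches := numPatches + 8 - (numPatches%8)%8

	fmt.Println(numPatches, numPaddedPatches) // 1601 1608
}
```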
diff --git a/go.mod b/go.mod
index af0cedc86..cc5789005 100644
--- a/go.mod
+++ b/go.mod
@@ -24,7 +24,7 @@ require (
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
golang.org/x/image v0.22.0
- gonum.org/v1/gonum v0.15.0
+ golang.org/x/tools v0.30.0
)
require (
@@ -44,6 +44,7 @@ require (
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+ gonum.org/v1/gonum v0.15.0 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect
)
diff --git a/go.sum b/go.sum
index 013a7db71..0ab97b909 100644
--- a/go.sum
+++ b/go.sum
@@ -309,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
+golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
+golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/kvcache/cache.go b/kvcache/cache.go
index 5d8b2f9b5..d35489057 100644
--- a/kvcache/cache.go
+++ b/kvcache/cache.go
@@ -4,6 +4,7 @@ import (
"errors"
"github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/model/input"
)
var (
@@ -29,6 +30,17 @@ type Cache interface {
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
+ // SetConfig controls optimizations (mostly backend-specific) that may transform
+ // the output of the cache to work better with specific kernels. If not called,
+ // the backend settings will be used. This works well when calling Attention.
+ //
+ // The config can be overridden by models, especially if they require vanilla
+ // output when implementing their own version of attention. To do this, pass
+ // an empty ml.CacheConfig.
+ //
+ // Most models will not need to use this.
+ SetConfig(ml.CacheConfig)
+
// ** cache management **
// Init sets up runtime parameters
@@ -40,7 +52,7 @@ type Cache interface {
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs.
- StartForward(ctx ml.Context, positions []int32, seqs []int) error
+ StartForward(ctx ml.Context, opts input.Options) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
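
For reference, a model that implements its own attention and needs untransformed cache output can opt out of backend optimizations by passing an empty `ml.CacheConfig` before the first forward pass, as the comment above describes. A hedged sketch; the model type and constructor are hypothetical:

```go
package example

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// myModel is hypothetical; only Cache.SetConfig and ml.CacheConfig come from
// the interface above.
type myModel struct {
	cache kvcache.Cache
}

func newMyModel() *myModel {
	m := &myModel{cache: kvcache.NewCausalCache(nil)}
	// An empty config opts out of backend-specific transforms and requests
	// vanilla cache output, as described in the SetConfig comment.
	m.cache.SetConfig(ml.CacheConfig{})
	return m
}
```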
diff --git a/kvcache/causal.go b/kvcache/causal.go
index 69068439e..34d5337cf 100644
--- a/kvcache/causal.go
+++ b/kvcache/causal.go
@@ -8,6 +8,7 @@ import (
"slices"
"github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/model/input"
)
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
@@ -20,8 +21,12 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type Causal struct {
DType ml.DType
Capacity int32
+ causal bool
windowSize int32
+ // config controls mostly backend-specific optimizations
+ config *ml.CacheConfig
+
// ** current forward pass **
// the active layer for Get and Put
@@ -39,6 +44,12 @@ type Causal struct {
// locations in the cache that are needed for this batch
curCellRange cellRange
+ // curSequences is the sequences corresponding to this pass's entries in the cache
+ curSequences []int
+
+ // curPositions is the positions corresponding to this pass's entries in the cache
+ curPositions []int32
+
// ** cache metadata **
// for each possible location in the cache, stores the position and set of sequences
@@ -52,8 +63,8 @@ type Causal struct {
shiftFn shiftFn
backend ml.Backend
- cacheCtx ml.Context
- keys, values []ml.Tensor
+ ctxs map[int]ml.Context
+ keys, values map[int]ml.Tensor
}
type cacheCell struct {
@@ -67,28 +78,73 @@ type cellRange struct {
}
func NewCausalCache(shift shiftFn) *Causal {
- return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
+ return &Causal{
+ causal: true,
+ windowSize: math.MaxInt32,
+ shiftFn: shift,
+ ctxs: make(map[int]ml.Context),
+ keys: make(map[int]ml.Tensor),
+ values: make(map[int]ml.Tensor),
+ }
}
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
- return &Causal{windowSize: windowSize, shiftFn: shift}
+ return &Causal{
+ causal: true,
+ windowSize: windowSize,
+ shiftFn: shift,
+ ctxs: make(map[int]ml.Context),
+ keys: make(map[int]ml.Tensor),
+ values: make(map[int]ml.Tensor),
+ }
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+ if c.config == nil {
+ var config ml.CacheConfig
+ if cc, ok := backend.(ml.BackendCacheConfig); ok {
+ config = cc.CacheConfig()
+ }
+ c.config = &config
+ }
+
+ if c.config.CachePadding == 0 {
+ c.config.CachePadding = 1
+ }
+
+ if c.config.MaskBatchPadding == 0 {
+ c.config.MaskBatchPadding = 1
+ }
+
+ if c.config.MaskDType == ml.DTypeOther {
+ c.config.MaskDType = ml.DTypeF32
+ }
+
c.DType = dtype
- c.Capacity = capacity
- c.cells = make([]cacheCell, capacity)
+ c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
+ c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
- c.cacheCtx = backend.NewContext()
+}
+
+func (c *Causal) SetConfig(config ml.CacheConfig) {
+ if c.config != nil {
+ panic("config cannot be changed after being previously set, either by the model or backend")
+ }
+
+ c.config = &config
}
func (c *Causal) Close() {
- c.cacheCtx.Close()
+ for _, ctx := range c.ctxs {
+ ctx.Close()
+ }
}
-func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
- c.curBatchSize = len(positions)
+func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
+ c.curBatchSize = len(opts.Positions)
+ c.curSequences = opts.Sequences
+ c.curPositions = opts.Positions
var err error
c.curLoc, err = c.findStartLoc()
@@ -101,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
}
c.curCellRange = newRange()
- for i, pos := range positions {
- seq := seqs[i]
+ for i, pos := range opts.Positions {
+ seq := opts.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
@@ -127,7 +183,7 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
c.cellRanges[seq] = seqRange
}
- c.curMask, err = c.buildMask(ctx, positions, seqs)
+ c.curMask, err = c.buildMask(ctx)
return err
}
@@ -157,36 +213,90 @@ func (c *Causal) findStartLoc() (int, error) {
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
+func roundDown(length, pad int) int {
+ return (length / pad) * pad
+}
+
+func roundUp(length, pad int) int {
+ return ((length + pad - 1) / pad) * pad
+}
+
// Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
-func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
- // TODO(jessegross): This does not do padding, which is required for flash attention
- len := c.curCellRange.max - c.curCellRange.min + 1
- mask := make([]float32, c.curBatchSize*len)
+func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
+ // Align and pad the two dimensions as required by the backend
+ batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
+
+ c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
+ c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
+
+ length := c.curCellRange.max - c.curCellRange.min + 1
+ mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
- if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
- c.cells[j].pos < positions[i]-c.windowSize {
- mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
+ if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
+ (c.causal && c.cells[j].pos > c.curPositions[i]) ||
+ c.cells[j].pos < c.curPositions[i]-c.windowSize {
+ mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}
}
}
- return ctx.FromFloatSlice(mask, len, c.curBatchSize)
+ // Mask out any padding tokens we added. For padding that we added to the cache history, this
+ // has already been masked out because the sequence doesn't match.
+ for i := c.curBatchSize * length; i < len(mask); i++ {
+ mask[i] = float32(math.Inf(-1))
+ }
+
+ maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
+ if err != nil {
+ return nil, err
+ }
+
+ if c.config.MaskDType != ml.DTypeF32 {
+ out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
+ ctx.Forward(maskTensor.Copy(ctx, out))
+ maskTensor = out
+ }
+
+ return maskTensor, nil
}
-func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
- for _, obj := range objs {
- if obj == nil {
+func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
+ for i, key := range c.keys {
+ if key == nil {
continue
}
- srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
- dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
+ kHeadDim := key.Dim(0)
+ numKVHeads := key.Dim(1)
+ rowSize := key.Stride(2)
- ctx.Forward(srcView.Copy(ctx, dstView))
+ kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
+ kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
+
+ value := c.values[i]
+ var vSrcView, vDstView ml.Tensor
+ if c.config.PermutedV {
+ vHeadDim := value.Dim(1)
+ elemSize := value.Stride(0)
+
+ vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+ vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+ } else {
+ vHeadDim := value.Dim(0)
+ rowSize := value.Stride(2)
+
+ vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
+ vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
+ }
+
+ ctx.Forward(
+ kSrcView.Copy(ctx, kDstView),
+ vSrcView.Copy(ctx, vDstView),
+ )
}
}
@@ -219,7 +329,7 @@ func (c *Causal) defrag() {
layers++
}
- maxMoves := ctx.MaxTensors() / (6 * layers)
+ maxMoves := ctx.MaxGraphNodes() / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
@@ -238,8 +348,7 @@ func (c *Causal) defrag() {
pendingLen++
break
} else {
- moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
- moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+ c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}
@@ -263,8 +372,7 @@ func (c *Causal) defrag() {
}
if pendingLen > 0 {
- moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
- moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+ c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
@@ -293,47 +401,106 @@ func (c *Causal) defrag() {
}
func (c *Causal) SetLayer(layer int) {
- if layer >= len(c.keys) {
- c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
- c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
- }
-
c.curLayer = layer
}
+// SetCausal enables or disables causal mask generation for subsequent calls to Get.
+// This state carries over to future forward passes. The default value is true.
+//
+// ctx may be set to nil if this is called from outside of a forward pass, for
+// example, when initializing the cache.
+func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
+ if c.causal != causal {
+ c.causal = causal
+
+ if ctx != nil {
+ var err error
+ c.curMask, err = c.buildMask(ctx)
+ if err != nil {
+ // This error should never occur because we have previously built a mask with the same shape
+ panic(fmt.Errorf("SetCausal: %w", err))
+ }
+ }
+ }
+}
+
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
- key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
- key.Dim(0), key.Stride(1),
- key.Dim(1), key.Stride(2),
- c.curMask.Dim(0),
+ kHeadDim := key.Dim(0)
+ numKVHeads := key.Dim(1)
+ rowSize := key.Stride(2)
+ cachedSize := c.curMask.Dim(0)
+
+ key = key.View(ctx, rowSize*c.curCellRange.min,
+ kHeadDim, key.Stride(1),
+ numKVHeads, key.Stride(2),
+ cachedSize,
)
- value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
- value.Dim(0), value.Stride(1),
- value.Dim(1), value.Stride(2),
- c.curMask.Dim(0),
- )
+ if c.config.PermutedV {
+ vHeadDim := value.Dim(1)
+ elemSize := value.Stride(0)
+
+ value = value.View(ctx, elemSize*c.curCellRange.min,
+ cachedSize, value.Stride(1),
+ vHeadDim, value.Stride(2),
+ numKVHeads,
+ )
+ } else {
+ vHeadDim := value.Dim(0)
+ rowSize := value.Stride(2)
+
+ value = value.View(ctx, rowSize*c.curCellRange.min,
+ vHeadDim, value.Stride(1),
+ numKVHeads, value.Stride(2),
+ cachedSize,
+ )
+ }
return key, value, c.curMask
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
- if c.curBatchSize != key.Dim(2) {
- panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
+ kHeadDim := key.Dim(0)
+ vHeadDim := value.Dim(0)
+ numKVHeads := key.Dim(1)
+ batchSize := key.Dim(2)
+
+ if c.curBatchSize != batchSize {
+ panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
}
- if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
- c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
- c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
+ if _, ok := c.ctxs[c.curLayer]; !ok {
+ c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
}
- ctx.Forward(
- key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
- value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
- )
+ if _, ok := c.keys[c.curLayer]; !ok {
+ c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+ }
+
+ if _, ok := c.values[c.curLayer]; !ok {
+ if c.config.PermutedV {
+ c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+ } else {
+ c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+ }
+ }
+
+ rowSize := c.keys[c.curLayer].Stride(2)
+ ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
+
+ if c.config.PermutedV {
+ elemSize := c.values[c.curLayer].Stride(0)
+
+ value = value.Permute(ctx, 1, 2, 0, 3)
+ ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
+ } else {
+ rowSize := c.values[c.curLayer].Stride(2)
+
+ ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
+ }
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
@@ -379,7 +546,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
}
}
- kShift, err := ctx.FromIntSlice(offsets, len(offsets))
+ kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
if err != nil {
return err
}
@@ -389,9 +556,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
continue
}
- key = key.View(ctx, key.Stride(2)*seqRange.min,
- key.Dim(0), key.Stride(1),
- key.Dim(1), key.Stride(2),
+ kHeadDim := key.Dim(0)
+ numKVHeads := key.Dim(1)
+ rowSize := key.Stride(2)
+
+ key = key.View(ctx, rowSize*seqRange.min,
+ kHeadDim, key.Stride(1),
+ numKVHeads, key.Stride(2),
size,
)
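
For reference, a worked example of the new mask padding: both dimensions are aligned with `roundDown`/`roundUp` using illustrative backend values (`CachePadding=32`, `MaskBatchPadding=64`), a 3-token batch, and a cell range of [5, 70]:

```go
package main

import "fmt"

func roundDown(length, pad int) int { return (length / pad) * pad }
func roundUp(length, pad int) int   { return ((length + pad - 1) / pad) * pad }

func main() {
	// Illustrative backend values: CachePadding=32, MaskBatchPadding=64,
	// a 3-token batch and a cell range of [5, 70].
	batchSize := roundUp(3, 64)      // 64
	minCell := roundDown(5, 32)      // 0
	maxCell := roundUp(70+1, 32) - 1 // 95
	length := maxCell - minCell + 1  // 96

	fmt.Println(batchSize, minCell, maxCell, length) // 64 0 95 96
}
```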
diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index bd7d0ae8b..22d8efb43 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -6,6 +6,7 @@ import (
"testing"
"github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/model/input"
)
type testCase struct {
@@ -269,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext()
defer context.Close()
- err := cache.StartForward(context, test.pos, test.seqs)
+ err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
if err != nil {
panic(err)
}
@@ -303,13 +304,17 @@ func (b *testBackend) NewContext() ml.Context {
return &testContext{}
}
+func (b *testBackend) NewContextSize(int) ml.Context {
+ return &testContext{}
+}
+
func (b *testBackend) SystemInfo() string {
return "not implemented"
}
type testContext struct{}
-func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
+func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
total := 0
if len(shape) > 0 {
@@ -322,8 +327,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
}
+func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
+ return c.Empty(dtype, shape...)
+}
+
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
- t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
+ t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s)
@@ -342,11 +351,15 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return out, nil
}
+func (c *testContext) Input() ml.Context { return c }
+func (c *testContext) Output() ml.Context { return c }
+func (c *testContext) Layer(int) ml.Context { return c }
+
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {}
-func (c *testContext) MaxTensors() int {
+func (c *testContext) MaxGraphNodes() int {
return 10
}
@@ -391,7 +404,7 @@ func (t *testTensor) Floats() []float32 {
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
- out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
+ out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
@@ -468,7 +481,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
context := &testContext{}
- view := context.Zeros(t.dtype, s...).(*testTensor)
+ view := context.Empty(t.dtype, s...).(*testTensor)
view.data = t.data[offset : offset+len(view.data)]
return view
diff --git a/kvcache/encoder.go b/kvcache/encoder.go
index b85b1046a..6a9df2abc 100644
--- a/kvcache/encoder.go
+++ b/kvcache/encoder.go
@@ -1,7 +1,10 @@
package kvcache
import (
+ "fmt"
+
"github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/model/input"
)
// Encoder cache stores K and V tensors that are position independent
@@ -11,6 +14,9 @@ import (
//
// Not currently safe for multiple sequences
type EncoderCache struct {
+ // config controls mostly backend-specific optimizations
+ config *ml.CacheConfig
+
// ** current forward pass **
// the active layer for Get and Put
@@ -30,36 +36,59 @@ type EncoderCache struct {
encoderPos int32
// ** cache data storage **
-
- cacheCtx ml.Context
- keys, values []ml.Tensor
+ backend ml.Backend
+ ctxs map[int]ml.Context
+ keys, values map[int]ml.Tensor
}
func NewEncoderCache() *EncoderCache {
- return &EncoderCache{}
+ return &EncoderCache{
+ ctxs: make(map[int]ml.Context),
+ keys: make(map[int]ml.Tensor),
+ values: make(map[int]ml.Tensor),
+ }
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
- c.cacheCtx = backend.NewContext()
+ if c.config == nil {
+ var config ml.CacheConfig
+ if cc, ok := backend.(ml.BackendCacheConfig); ok {
+ config = cc.CacheConfig()
+ }
+ c.config = &config
+ }
+
+ if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
+ panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
+ }
+
+ c.backend = backend
+}
+
+func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
+ if c.config != nil {
+ panic("config cannot be changed after being previously set, either by the model or backend")
+ }
+
+ c.config = &config
}
func (c *EncoderCache) Close() {
- c.cacheCtx.Close()
+ for _, ctx := range c.ctxs {
+ ctx.Close()
+ }
}
-func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
- // The image is always in the first position
- c.curPos = positions[0]
+func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
+ // We work with the most recent image
+ if len(opts.Multimodal) > 0 {
+ c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
+ }
return nil
}
func (c *EncoderCache) SetLayer(layer int) {
- if layer >= len(c.keys) {
- c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
- c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
- }
-
c.curLayer = layer
}
@@ -75,9 +104,20 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.encoderPos = c.curPos
c.encoderCached = true
- if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
- c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
- c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
+ if c.config.PermutedV {
+ value = value.Permute(ctx, 1, 2, 0, 3)
+ }
+
+ if _, ok := c.ctxs[c.curLayer]; !ok {
+ c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
+ }
+
+ if _, ok := c.keys[c.curLayer]; !ok {
+ c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
+ }
+
+ if _, ok := c.values[c.curLayer]; !ok {
+ c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
}
ctx.Forward(
diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go
index 2d4c1089a..aaccd1661 100644
--- a/kvcache/wrapper.go
+++ b/kvcache/wrapper.go
@@ -4,6 +4,7 @@ import (
"math"
"github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/model/input"
)
// Wrapper cache is a container for multiple types of caches,
@@ -28,20 +29,26 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
}
}
+func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
+ for _, cache := range c.caches {
+ cache.SetConfig(config)
+ }
+}
+
func (c *WrapperCache) Close() {
for _, cache := range c.caches {
cache.Close()
}
}
-func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
for i, cache := range c.caches {
- err := cache.StartForward(ctx, positions, seqs)
+ err := cache.StartForward(ctx, opts)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {
- for k := range positions {
- _ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
+ for k := range opts.Positions {
+ _ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
}
}
return err
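
For reference, callers now pass an `input.Options` bundle to `StartForward` instead of separate slices. A hedged sketch of building it for a 4-token batch in a single sequence; the wrapper function name is hypothetical and the `Multimodal` field is left empty:

```go
package example

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// startBatch is a hypothetical caller: positions and sequences for a 4-token
// batch, all in sequence 0; Multimodal (used by the encoder cache) is left empty.
func startBatch(ctx ml.Context, cache kvcache.Cache) error {
	opts := input.Options{
		Positions: []int32{0, 1, 2, 3},
		Sequences: []int{0, 0, 0, 0},
	}
	return cache.StartForward(ctx, opts)
}
```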
diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp
index c7ff28be1..7a185443a 100644
--- a/llama/llama.cpp/src/llama-vocab.cpp
+++ b/llama/llama.cpp/src/llama-vocab.cpp
@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
if (precompiled_charsmap_keyidx != -1) {
- size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+ size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
#ifdef IS_BIG_ENDIAN
diff --git a/llama/llama.go b/llama/llama.go
index 0c4fca430..a026bee24 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -21,18 +21,6 @@ package llama
extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);
-
-typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
-COMPILER inline get_compiler() {
-#if defined(__clang__)
- return COMP_CLANG;
-#elif defined(__GNUC__)
- return COMP_GCC;
-#else
- return UNKNOWN_COMPILER;
-#endif
-}
-
*/
import "C"
@@ -72,19 +60,6 @@ func BackendInit() {
C.llama_backend_init()
}
-func PrintSystemInfo() string {
- var compiler string
- switch C.get_compiler() {
- case C.COMP_UNKNOWN:
- compiler = "cgo(unknown_compiler)"
- case C.COMP_GCC:
- compiler = "cgo(gcc)"
- case C.COMP_CLANG:
- compiler = "cgo(clang)"
- }
- return C.GoString(C.llama_print_system_info()) + compiler
-}
-
func GetModelArch(modelPath string) (string, error) {
mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp))
@@ -270,6 +245,20 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
return &m, nil
}
+func LoadVocabFromFile(path string) (*Vocab, error) {
+ mp := C.CString(path)
+ defer C.free(unsafe.Pointer(mp))
+ v := Vocab{c: C.llama_load_vocab_from_file(mp)}
+ if v.c == nil {
+ return nil, fmt.Errorf("unable to load vocab: %s", path)
+ }
+ return &v, nil
+}
+
+func FreeVocab(vocab *Vocab) {
+ C.llama_free_vocab(vocab.c)
+}
+
func FreeModel(model *Model) {
C.llama_model_free(model.c)
}
@@ -318,6 +307,10 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
return nil
}
+type Vocab struct {
+ c *C.struct_llama_vocab
+}
+
func (m *Model) Vocab() *C.struct_llama_vocab {
return C.llama_model_get_vocab(m.c)
}
@@ -694,3 +687,53 @@ func SchemaToGrammar(schema []byte) []byte {
}
return buf[:n]
}
+
+type Sampler struct {
+ c *C.struct_llama_sampler
+}
+
+func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
+ cGrammar := C.CString(grammar)
+ cRoot := C.CString("root")
+ defer C.free(unsafe.Pointer(cGrammar))
+ defer C.free(unsafe.Pointer(cRoot))
+
+ sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
+
+ return sampler
+}
+
+func (s *Sampler) Accept(token int32) {
+ C.llama_sampler_accept(s.c, C.llama_token(token))
+}
+
+type TokenData struct {
+ Id int32
+ Logit float32
+}
+
+func (s *Sampler) Apply(tokens []TokenData) {
+ tds := make([]C.struct_llama_token_data, len(tokens))
+ for i, token := range tokens {
+ tds[i] = C.struct_llama_token_data{
+ id: C.int32_t(token.Id),
+ logit: C.float(token.Logit),
+ p: C.float(0.0),
+ }
+ }
+ tda := &C.llama_token_data_array{
+ data: (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
+ size: C.size_t(len(tokens)),
+ selected: C.int64_t(-1),
+ sorted: C.bool(false),
+ }
+
+ var pinner runtime.Pinner
+ pinner.Pin(&tds[0])
+ defer pinner.Unpin()
+
+ C.llama_sampler_apply(s.c, tda)
+ for i := range tokens {
+ tokens[i].Logit = float32(tds[i].logit)
+ }
+}
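
For reference, a hedged usage sketch of the new vocab and grammar-sampler bindings added above; the GBNF grammar, file path, and surrounding function are placeholders:

```go
package example

import "github.com/ollama/ollama/llama"

// constrain is a placeholder showing the call sequence for grammar-constrained
// sampling; the grammar and file path are illustrative.
func constrain(logits map[int32]float32) error {
	vocab, err := llama.LoadVocabFromFile("/path/to/model.gguf")
	if err != nil {
		return err
	}
	defer llama.FreeVocab(vocab)

	sampler := llama.NewGrammarSampler(vocab, `root ::= "yes" | "no"`)

	tokens := make([]llama.TokenData, 0, len(logits))
	for id, logit := range logits {
		tokens = append(tokens, llama.TokenData{Id: id, Logit: logit})
	}

	// Apply the grammar; tokens the grammar disallows come back with -Inf logits.
	sampler.Apply(tokens)

	// After a token is actually chosen, advance the grammar state with Accept.
	if len(tokens) > 0 {
		sampler.Accept(tokens[0].Id) // illustrative pick
	}
	return nil
}
```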
diff --git a/llama/patches/0015-try-catch-backend-load.patch b/llama/patches/0015-try-catch-backend-load.patch
deleted file mode 100644
index 9aea61836..000000000
--- a/llama/patches/0015-try-catch-backend-load.patch
+++ /dev/null
@@ -1,69 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang
-Date: Tue, 11 Feb 2025 14:06:36 -0800
-Subject: [PATCH] try/catch backend load
-
----
- ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++-----------------
- 1 file changed, 23 insertions(+), 22 deletions(-)
-
-diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 98d5e14d..1c19129a 100644
---- a/ggml/src/ggml-backend-reg.cpp
-+++ b/ggml/src/ggml-backend-reg.cpp
-@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
- }
- fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
- for (const auto & entry : dir_it) {
-- if (entry.is_regular_file()) {
-- std::wstring filename = entry.path().filename().wstring();
-- std::wstring ext = entry.path().extension().wstring();
-- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
-- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
-- if (!handle && !silent) {
-- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-- }
-- if (handle) {
-+ try {
-+ if (entry.is_regular_file()) {
-+ std::wstring filename = entry.path().filename().wstring();
-+ std::wstring ext = entry.path().extension().wstring();
-+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
-+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
-+ if (!handle) {
-+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-+ continue;
-+ }
-+
- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
-- if (score_fn) {
-- int s = score_fn();
--#ifndef NDEBUG
-- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
--#endif
-- if (s > best_score) {
-- best_score = s;
-- best_path = entry.path().wstring();
-- }
-- } else {
-- if (!silent) {
-- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-- }
-+ if (!score_fn) {
-+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-+ continue;
-+ }
-+
-+ int s = score_fn();
-+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
-+ if (s > best_score) {
-+ best_score = s;
-+ best_path = entry.path().wstring();
- }
- }
- }
-+ } catch (const std::exception & e) {
-+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
- }
- }
- }
diff --git a/llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch b/llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch
similarity index 67%
rename from llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch
rename to llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch
index d60066c13..e72d78ac8 100644
--- a/llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch
+++ b/llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch
@@ -4,11 +4,11 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500
Subject: [PATCH] use std::filesystem::path instead of wstring
---
- ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++--------------------
- 1 file changed, 58 insertions(+), 86 deletions(-)
+ ggml/src/ggml-backend-reg.cpp | 199 +++++++++++++++-------------------
+ 1 file changed, 88 insertions(+), 111 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 1c19129a..c854e6bb 100644
+index 98d5e14d..799af5f3 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -66,26 +66,6 @@
@@ -264,47 +264,55 @@ index 1c19129a..c854e6bb 100644
for (const auto & search_path : search_paths) {
if (!fs::exists(search_path)) {
continue;
-@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
+@@ -513,29 +485,26 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
+ fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
- try {
- if (entry.is_regular_file()) {
-- std::wstring filename = entry.path().filename().wstring();
-- std::wstring ext = entry.path().extension().wstring();
-+ std::string filename = entry.path().filename().string();
-+ std::string ext = entry.path().extension().string();
- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
-- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
-+ dl_handle_ptr handle { dl_load_library(entry.path()) };
- if (!handle) {
-- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
- continue;
- }
-
- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
- if (!score_fn) {
-- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
- continue;
- }
-
- int s = score_fn();
-- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
-+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
- if (s > best_score) {
- best_score = s;
-- best_path = entry.path().wstring();
-+ best_path = entry.path();
- }
+ if (entry.is_regular_file()) {
+- std::wstring filename = entry.path().filename().wstring();
+- std::wstring ext = entry.path().extension().wstring();
++ std::string filename = entry.path().filename().string();
++ std::string ext = entry.path().extension().string();
+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+- if (!handle && !silent) {
+- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
++ dl_handle_ptr handle { dl_load_library(entry.path()) };
++ if (!handle) {
++ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
++ continue;
+ }
+- if (handle) {
+- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+- if (score_fn) {
+- int s = score_fn();
+-#ifndef NDEBUG
+- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+-#endif
+- if (s > best_score) {
+- best_score = s;
+- best_path = entry.path().wstring();
+- }
+- } else {
+- if (!silent) {
+- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+- }
+- }
++
++ auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
++ if (!score_fn) {
++ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
++ continue;
++ }
++
++ int s = score_fn();
++ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
++ if (s > best_score) {
++ best_score = s;
++ best_path = entry.path();
}
}
- } catch (const std::exception & e) {
-- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
-+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
}
- }
- }
-@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
+@@ -545,7 +514,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
if (best_score == 0) {
// try to load the base backend
for (const auto & search_path : search_paths) {
@@ -313,3 +321,49 @@ index 1c19129a..c854e6bb 100644
if (fs::exists(path)) {
return get_reg().load_backend(path, silent);
}
+@@ -560,6 +529,14 @@ void ggml_backend_load_all() {
+ ggml_backend_load_all_from_path(nullptr);
+ }
+
++static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
++ try {
++ ggml_backend_load_best(name, silent, user_search_path);
++ } catch (const std::exception & e) {
++ GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
++ }
++}
++
+ void ggml_backend_load_all_from_path(const char * dir_path) {
+ #ifdef NDEBUG
+ bool silent = true;
+@@ -567,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
+ bool silent = false;
+ #endif
+
+- ggml_backend_load_best("blas", silent, dir_path);
+- ggml_backend_load_best("cann", silent, dir_path);
+- ggml_backend_load_best("cuda", silent, dir_path);
+- ggml_backend_load_best("hip", silent, dir_path);
+- ggml_backend_load_best("kompute", silent, dir_path);
+- ggml_backend_load_best("metal", silent, dir_path);
+- ggml_backend_load_best("rpc", silent, dir_path);
+- ggml_backend_load_best("sycl", silent, dir_path);
+- ggml_backend_load_best("vulkan", silent, dir_path);
+- ggml_backend_load_best("opencl", silent, dir_path);
+- ggml_backend_load_best("musa", silent, dir_path);
+- ggml_backend_load_best("cpu", silent, dir_path);
++ ggml_backend_try_load_best("blas", silent, dir_path);
++ ggml_backend_try_load_best("cann", silent, dir_path);
++ ggml_backend_try_load_best("cuda", silent, dir_path);
++ ggml_backend_try_load_best("hip", silent, dir_path);
++ ggml_backend_try_load_best("kompute", silent, dir_path);
++ ggml_backend_try_load_best("metal", silent, dir_path);
++ ggml_backend_try_load_best("rpc", silent, dir_path);
++ ggml_backend_try_load_best("sycl", silent, dir_path);
++ ggml_backend_try_load_best("vulkan", silent, dir_path);
++ ggml_backend_try_load_best("opencl", silent, dir_path);
++ ggml_backend_try_load_best("musa", silent, dir_path);
++ ggml_backend_try_load_best("cpu", silent, dir_path);
+ // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
+ const char * backend_path = std::getenv("GGML_BACKEND_PATH");
+ if (backend_path) {
diff --git a/llama/patches/0017-remove-amx.patch b/llama/patches/0016-remove-amx.patch
similarity index 100%
rename from llama/patches/0017-remove-amx.patch
rename to llama/patches/0016-remove-amx.patch
diff --git a/llama/patches/0018-fix-clip-compiler-error.patch b/llama/patches/0017-fix-clip-compiler-error.patch
similarity index 100%
rename from llama/patches/0018-fix-clip-compiler-error.patch
rename to llama/patches/0017-fix-clip-compiler-error.patch
diff --git a/llama/patches/0019-add-phi4-support.patch b/llama/patches/0018-add-phi4-support.patch
similarity index 100%
rename from llama/patches/0019-add-phi4-support.patch
rename to llama/patches/0018-add-phi4-support.patch
diff --git a/llama/patches/0019-fix-string-arr-kv-loading.patch b/llama/patches/0019-fix-string-arr-kv-loading.patch
new file mode 100644
index 000000000..aa7b4d3cf
--- /dev/null
+++ b/llama/patches/0019-fix-string-arr-kv-loading.patch
@@ -0,0 +1,64 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca
+Date: Wed, 5 Mar 2025 17:41:07 -0800
+Subject: [PATCH] fix string arr kv loading
+
+---
+ ggml/include/gguf.h | 1 +
+ ggml/src/gguf.cpp | 7 +++++--
+ src/llama-vocab.cpp | 2 +-
+ 3 files changed, 7 insertions(+), 3 deletions(-)
+
+diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
+index 79ee2020..3efb22f0 100644
+--- a/ggml/include/gguf.h
++++ b/ggml/include/gguf.h
+@@ -114,6 +114,7 @@ extern "C" {
+ // get raw pointer to the first element of the array with the given key_id
+ // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
+ GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
++ GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);
+
+ // get ith C string from array with given key_id
+ GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
+diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
+index ab13669c..f75b923f 100644
+--- a/ggml/src/gguf.cpp
++++ b/ggml/src/gguf.cpp
+@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
+
+ const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
+ return ctx->kv[key_id].data.data();
+ }
+
++size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
++ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
++ return ctx->kv[key_id].data.size();
++}
++
+ const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
+@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
+ const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
+ return ctx->kv[key_id].data.data();
+ }
+
+diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
+index c7ff28be..7a185443 100644
+--- a/src/llama-vocab.cpp
++++ b/src/llama-vocab.cpp
+@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+
+ const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
+ if (precompiled_charsmap_keyidx != -1) {
+- size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
++ size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
+ const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
+ precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
+ #ifdef IS_BIG_ENDIAN
diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp
index 0f137dc8d..b816cedd4 100644
--- a/llama/sampling_ext.cpp
+++ b/llama/sampling_ext.cpp
@@ -2,6 +2,9 @@
#include "sampling.h"
#include "sampling_ext.h"
#include "json-schema-to-grammar.h"
+#include "llama.h"
+#include "llama-model.h"
+#include "llama-model-loader.h"
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
try {
@@ -64,3 +67,22 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
return 0;
}
}
+
+struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
+ llama_vocab * vocab = new llama_vocab();
+ try {
+ const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
+    std::vector<std::string> splits = {};
+ llama_model_loader ml(std::string(fname), splits, false, false, nullptr);
+ vocab->load(ml, kv);
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
+ return nullptr;
+ }
+
+ return vocab;
+}
+
+void llama_free_vocab(struct llama_vocab * vocab) {
+ delete vocab;
+}
diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h
index 39f499f19..9be7c100e 100644
--- a/llama/sampling_ext.h
+++ b/llama/sampling_ext.h
@@ -35,6 +35,9 @@ extern "C"
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
+ struct llama_vocab * llama_load_vocab_from_file(const char * fname);
+ void llama_free_vocab(struct llama_vocab * vocab);
+
#ifdef __cplusplus
}
#endif
diff --git a/llm/memory.go b/llm/memory.go
index 1da4d2c08..40104eca9 100644
--- a/llm/memory.go
+++ b/llm/memory.go
@@ -115,6 +115,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
+ if projectorWeights == 0 && projectorGraph == 0 {
+ projectorWeights, projectorGraph = f.VisionGraphSize()
+ }
layers := f.Tensors().GroupLayers()
// add one layer worth of memory as a buffer
diff --git a/llm/server.go b/llm/server.go
index fd027a535..a53306fb0 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -30,6 +30,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llama"
+ "github.com/ollama/ollama/model"
)
type LlamaServer interface {
@@ -54,8 +55,15 @@ type llmServer struct {
options api.Options
numParallel int
modelPath string
- modelLock sync.Mutex // Temporary until we switch fully to Go server
- model *llama.Model // If non-nil, the runner is a new Go server
+
+ // llamaModel is an instance of the cgo llama.cpp model definition
+ // nil if this server is running the new engine
+ llamaModel *llama.Model
+ llamaModelLock sync.Mutex
+
+ // textProcessor handles text encoding/decoding for the model in the Ollama engine
+ // nil if this server is running the llama.cpp based engine
+ textProcessor model.TextProcessor
estimate MemoryEstimate
totalLayers uint64
@@ -89,7 +97,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
-func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
+func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
systemInfo := discover.GetSystemInfo()
systemTotalMemory := systemInfo.System.TotalMemory
systemFreeMemory := systemInfo.System.FreeMemory
@@ -130,7 +138,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
slog.Info("offload", "", estimate)
params := []string{
- "--model", model,
+ "--model", modelPath,
"--ctx-size", strconv.Itoa(opts.NumCtx),
"--batch-size", strconv.Itoa(opts.NumBatch),
}
@@ -153,11 +161,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
}
}
- if len(projectors) > 0 {
- // TODO: applying multiple projectors is not supported by the llama.cpp server yet
- params = append(params, "--mmproj", projectors[0])
- }
-
defaultThreads := systemInfo.GetOptimalThreadCount()
if opts.NumThread > 0 {
params = append(params, "--threads", strconv.Itoa(opts.NumThread))
@@ -257,6 +260,34 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
}
}
slog.Debug("compatible gpu libraries", "compatible", compatible)
+ exe, err := os.Executable()
+ if err != nil {
+ return nil, fmt.Errorf("unable to lookup executable path: %w", err)
+ }
+
+ if eval, err := filepath.EvalSymlinks(exe); err == nil {
+ exe = eval
+ }
+
+ var llamaModel *llama.Model
+ var textProcessor model.TextProcessor
+ if envconfig.NewEngine() {
+ textProcessor, err = model.NewTextProcessor(modelPath)
+ if err != nil {
+			// To prepare for opt-out mode, instead of treating this as an error, we fall back to the old runner
+ slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
+ }
+ }
+ if textProcessor == nil {
+ llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if len(projectors) > 0 && llamaModel != nil {
+ params = append(params, "--mmproj", projectors[0])
+ }
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
@@ -275,7 +306,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
}
finalParams := []string{"runner"}
- if envconfig.NewEngine() {
+ if textProcessor != nil {
+ // New engine
+ // TODO - if we have failure to load scenarios, add logic to retry with the old runner
finalParams = append(finalParams, "--ollama-engine")
}
finalParams = append(finalParams, params...)
@@ -315,28 +348,20 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
// finally, add the root library path
libraryPaths = append(libraryPaths, discover.LibOllamaPath)
- exe, err := os.Executable()
- if err != nil {
- return nil, fmt.Errorf("unable to lookup executable path: %w", err)
- }
-
- if eval, err := filepath.EvalSymlinks(exe); err == nil {
- exe = eval
- }
-
- // TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access
s := &llmServer{
- port: port,
- cmd: exec.Command(exe, finalParams...),
- status: NewStatusWriter(os.Stderr),
- options: opts,
- modelPath: model,
- estimate: estimate,
- numParallel: numParallel,
- sem: semaphore.NewWeighted(int64(numParallel)),
- totalLayers: f.KV().BlockCount() + 1,
- gpus: gpus,
- done: make(chan error, 1),
+ port: port,
+ cmd: exec.Command(exe, finalParams...),
+ status: NewStatusWriter(os.Stderr),
+ options: opts,
+ modelPath: modelPath,
+ llamaModel: llamaModel,
+ textProcessor: textProcessor,
+ estimate: estimate,
+ numParallel: numParallel,
+ sem: semaphore.NewWeighted(int64(numParallel)),
+ totalLayers: f.KV().BlockCount() + 1,
+ gpus: gpus,
+ done: make(chan error, 1),
}
s.cmd.Env = os.Environ()
@@ -405,6 +430,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
}
err := fmt.Errorf("error starting runner: %v %s", err, msg)
if len(compatible) == 0 {
+ if llamaModel != nil {
+ llama.FreeModel(llamaModel)
+ }
return nil, err
}
@@ -933,64 +961,25 @@ type TokenizeResponse struct {
}
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
- s.modelLock.Lock()
- defer s.modelLock.Unlock()
- if s.model != nil {
- return s.model.Tokenize(content, false, true)
- }
+ s.llamaModelLock.Lock()
+ defer s.llamaModelLock.Unlock()
- // Make sure the server is ready
- status, err := s.getServerStatus(ctx)
- if err != nil {
- return nil, err
- } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
- return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
+ if s.llamaModel != nil {
+ return s.llamaModel.Tokenize(content, false, true)
}
-
- data, err := json.Marshal(TokenizeRequest{Content: content})
- if err != nil {
- return nil, fmt.Errorf("marshaling encode data: %w", err)
- }
-
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
- if err != nil {
- return nil, fmt.Errorf("encode request: %w", err)
- }
- req.Header.Set("Content-Type", "application/json")
-
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- return nil, fmt.Errorf("do encode request: %w", err)
- }
- defer resp.Body.Close()
- if resp.StatusCode == http.StatusNotFound {
- if s.model == nil {
- slog.Debug("new runner detected, loading model for cgo tokenization")
- m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
- if err != nil {
- return nil, err
- }
- s.model = m
+ if s.textProcessor != nil {
+ tokens, err := s.textProcessor.Encode(content, false)
+ if err != nil {
+ return nil, err
}
- return s.model.Tokenize(content, false, true)
+ toks := make([]int, len(tokens))
+ for i, t := range tokens {
+ toks[i] = int(t)
+ }
+ return toks, nil
}
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("read encode request: %w", err)
- }
-
- if resp.StatusCode >= 400 {
- log.Printf("llm encode error: %s", body)
- return nil, fmt.Errorf("%s", body)
- }
-
- var encoded TokenizeResponse
- if err := json.Unmarshal(body, &encoded); err != nil {
- return nil, fmt.Errorf("unmarshal encode response: %w", err)
- }
-
- return encoded.Tokens, nil
+ // not reached
+ return nil, fmt.Errorf("no tokenizer configured")
}
type DetokenizeRequest struct {
@@ -1002,80 +991,38 @@ type DetokenizeResponse struct {
}
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
- s.modelLock.Lock()
- defer s.modelLock.Unlock()
- if s.model != nil {
+ s.llamaModelLock.Lock()
+ defer s.llamaModelLock.Unlock()
+
+ if s.llamaModel != nil {
var resp string
for _, token := range tokens {
- resp += s.model.TokenToPiece(token)
+ resp += s.llamaModel.TokenToPiece(token)
}
return resp, nil
}
- // Make sure the server is ready
- status, err := s.getServerStatus(ctx)
- if err != nil {
- return "", err
- } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
- return "", fmt.Errorf("unexpected server status: %s", status.ToString())
- }
-
- data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
- if err != nil {
- return "", fmt.Errorf("marshaling decode data: %w", err)
- }
-
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
- if err != nil {
- return "", fmt.Errorf("decode request: %w", err)
- }
- req.Header.Set("Content-Type", "application/json")
-
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- return "", fmt.Errorf("do decode request: %w", err)
- }
- defer resp.Body.Close()
- if resp.StatusCode == http.StatusNotFound {
- if s.model == nil {
- slog.Debug("new runner detected, loading model for cgo tokenization")
- m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
- if err != nil {
- return "", err
- }
- s.model = m
+ if s.textProcessor != nil {
+ toks := make([]int32, len(tokens))
+ for i, t := range tokens {
+ toks[i] = int32(t)
}
- var resp string
- for _, token := range tokens {
- resp += s.model.TokenToPiece(token)
+ content, err := s.textProcessor.Decode(toks)
+ if err != nil {
+ return "", err
}
- return resp, nil
+ return content, nil
}
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return "", fmt.Errorf("read decode request: %w", err)
- }
-
- if resp.StatusCode >= 400 {
- log.Printf("llm decode error: %s", body)
- return "", fmt.Errorf("%s", body)
- }
-
- var decoded DetokenizeResponse
- if err := json.Unmarshal(body, &decoded); err != nil {
- return "", fmt.Errorf("unmarshal encode response: %w", err)
- }
-
- return decoded.Content, nil
+ // not reached
+ return "", fmt.Errorf("no tokenizer configured")
}
func (s *llmServer) Close() error {
- s.modelLock.Lock()
- if s.model != nil {
- llama.FreeModel(s.model)
- s.model = nil
+ s.llamaModelLock.Lock()
+ if s.llamaModel != nil {
+ llama.FreeModel(s.llamaModel)
+ s.llamaModel = nil
}
- s.modelLock.Unlock()
+ s.llamaModelLock.Unlock()
if s.cmd != nil {
slog.Debug("stopping llama server")
diff --git a/ml/backend.go b/ml/backend.go
index 07bc75b64..641175f0f 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -5,6 +5,7 @@ import (
"encoding/binary"
"fmt"
"os"
+ "slices"
"strconv"
"strings"
)
@@ -24,7 +25,36 @@ type Backend interface {
Config() Config
Get(name string) Tensor
NewContext() Context
- SystemInfo() string
+ NewContextSize(size int) Context
+}
+
+// BackendCacheConfig should be implemented by backends that need special output
+// from the cache to meet specific requirements. It is frequently implemented in
+// conjunction with ScaledDotProductAttention.
+type BackendCacheConfig interface {
+ CacheConfig() CacheConfig
+}
+
+// CacheConfig controls optimizations (mostly backend-specific) that may transform
+// the output of the cache to work better with specific kernels.
+type CacheConfig struct {
+ // CachePadding specifies the multiple for the number of tokens of cache history
+ // that will be returned from cache Get for k, v and mask. The capacity of the
+ // cache itself will also be increased to a multiple of this size if needed.
+ CachePadding int
+
+ // PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
+	// and returns the permuted version via Get. This uses the cache copy operation
+ // to avoid a Contiguous call on the permuted tensor.
+ PermutedV bool
+
+ // MaskDType specifies the data type for generating the mask. If unset it will
+ // default to DTypeF32.
+ MaskDType DType
+
+ // MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
+ // Any position that does not correspond to an actual token will be filled with -Inf.
+ MaskBatchPadding int
}
// BackendParams controls how the backend loads and executes models
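A KV cache implementation would typically query these hints before allocating its buffers. A small sketch under that assumption (the resolveCacheConfig and roundUp helpers are illustrative, not part of the interface):

func resolveCacheConfig(backend ml.Backend, capacity int) (ml.CacheConfig, int) {
	// Backends that don't implement BackendCacheConfig get neutral defaults.
	config := ml.CacheConfig{CachePadding: 1}
	if cc, ok := backend.(ml.BackendCacheConfig); ok {
		config = cc.CacheConfig()
	}

	// Grow the cache capacity to the padding multiple requested by the backend.
	return config, roundUp(capacity, config.CachePadding)
}

func roundUp(n, multiple int) int {
	return ((n + multiple - 1) / multiple) * multiple
}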
@@ -40,6 +70,9 @@ type BackendParams struct {
// TensorSplit is the fraction of the model to offload to each GPU
TensorSplit []float32
+
+ // FlashAttention indicates that we should use a fused flash attention kernel
+ FlashAttention bool
}
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
@@ -61,14 +94,24 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
}
type Context interface {
+ Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
FromIntSlice(s []int32, shape ...int) (Tensor, error)
Forward(...Tensor) Context
Compute(...Tensor)
- MaxTensors() int
+ MaxGraphNodes() int
Close()
+
+ // Input returns a context appropriate for creating input tensors
+ Input() Context
+
+ // Output returns a context appropriate for creating output tensors
+ Output() Context
+
+ // Layer returns a context appropriate for creating intermediate tensors
+ Layer(int) Context
}
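Models are expected to route tensor creation through these sub-contexts so the scheduler can place data next to the weights that consume it. A rough sketch of the intended usage (the embedding lookup and tensor name are placeholders):

func embedTokens(ctx ml.Context, backend ml.Backend, tokens []int32) (ml.Tensor, error) {
	// Token IDs are model inputs, so they are created on the Input context.
	ids, err := ctx.Input().FromIntSlice(tokens, len(tokens))
	if err != nil {
		return nil, err
	}

	// The row lookup runs on whichever device holds the embedding weights.
	return backend.Get("token_embd.weight").Rows(ctx, ids), nil
}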
type Tensor interface {
@@ -116,6 +159,10 @@ type Tensor interface {
// operation equivalent to following code on a tensor named
// query:
//
+// query = query.Permute(ctx, 0, 2, 1, 3)
+// key = key.Permute(ctx, 0, 2, 1, 3)
+// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
@@ -169,8 +216,8 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
})
- case DTypeF16:
- f32 := ctx.Zeros(DTypeF32, t.Shape()...)
+ case DTypeF16, DTypeQ80, DTypeQ40:
+ f32 := ctx.Empty(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
@@ -195,16 +242,17 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
}
shape := t.Shape()
+ slices.Reverse(shape)
var sb strings.Builder
var f func([]int, int)
f = func(dims []int, stride int) {
prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
- fmt.Fprint(&sb, "[")
- defer func() { fmt.Fprint(&sb, "]") }()
+ sb.WriteString("[")
+ defer func() { sb.WriteString("]") }()
for i := 0; i < dims[0]; i++ {
if i >= items && i < dims[0]-items {
- fmt.Fprint(&sb, "..., ")
+ sb.WriteString("..., ")
// skip to next printable element
skip := dims[0] - 2*items
if len(dims) > 1 {
@@ -219,9 +267,14 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
}
} else {
- fmt.Fprint(&sb, fn(s[stride+i]))
+ text := fn(s[stride+i])
+ if len(text) > 0 && text[0] != '-' {
+ sb.WriteString(" ")
+ }
+
+ sb.WriteString(text)
if i < dims[0]-1 {
- fmt.Fprint(&sb, ", ")
+ sb.WriteString(", ")
}
}
}
@@ -237,5 +290,7 @@ const (
DTypeOther DType = iota
DTypeF32
DTypeF16
+ DTypeQ80
+ DTypeQ40
DTypeI32
)
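The new Empty constructor and quantized DTypes are mainly useful for large, write-once buffers such as KV caches, where the zero fill performed by Zeros is wasted work. A sketch under that assumption (dimensions and the helper name are placeholders):

func newCacheBuffers(ctx ml.Context, headDim, kvHeads, capacity int) (ml.Tensor, ml.Tensor) {
	// Uninitialized q8_0 key buffer; contents will be written before they are read.
	keys := ctx.Empty(ml.DTypeQ80, headDim*kvHeads, capacity)

	// The mask still needs a defined value, so it is zeroed.
	mask := ctx.Zeros(ml.DTypeF32, capacity)
	return keys, mask
}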
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 7f91990c3..74512f337 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1,89 +1,61 @@
package ggml
-/*
-#cgo CPPFLAGS: -I${SRCDIR}/ggml/include
-#include <stdlib.h>
-#include <stdint.h>
-#include "ggml.h"
-#include "ggml-cpu.h"
-#include "ggml-backend.h"
-static struct ggml_backend_feature * getBackendFeatures(void *fp, ggml_backend_reg_t reg) {return ((ggml_backend_get_features_t)(fp))(reg);}
-static struct ggml_backend_feature * getNextBackendFeatures(struct ggml_backend_feature * feature) { return &feature[1];}
-
-typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
-COMPILER inline get_compiler() {
-#if defined(__clang__)
- return COMP_CLANG;
-#elif defined(__GNUC__)
- return COMP_GCC;
-#else
- return UNKNOWN_COMPILER;
-#endif
-}
-
-*/
+// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
+// #include <stdlib.h>
+// #include <stdint.h>
+// #include "ggml.h"
+// #include "ggml-cpu.h"
+// #include "ggml-backend.h"
import "C"
import (
+ "errors"
"fmt"
"io"
"log/slog"
+ "maps"
"os"
- "sync"
+ "slices"
+ "strconv"
+ "strings"
+ "unicode"
"unsafe"
"github.com/ollama/ollama/format"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
- "golang.org/x/sync/errgroup"
-
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+ "golang.org/x/sync/errgroup"
)
-type device struct {
- d *C.struct_ggml_backend_device
-}
-
-func (d device) LogValue() slog.Value {
- var free, total uint64
- C.ggml_backend_dev_memory(d.d, (*C.size_t)(&free), (*C.size_t)(&total))
-
- kind := "unknown"
- switch C.ggml_backend_dev_type(d.d) {
- case C.GGML_BACKEND_DEVICE_TYPE_CPU:
- kind = "cpu"
- case C.GGML_BACKEND_DEVICE_TYPE_GPU:
- kind = "gpu"
- case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
- kind = "accel"
- }
-
- return slog.GroupValue(
- slog.String("name", C.GoString(C.ggml_backend_dev_name(d.d))),
- slog.String("description", C.GoString(C.ggml_backend_dev_description(d.d))),
- slog.String("kind", kind),
- slog.String("free", format.HumanBytes2(free)),
- slog.String("total", format.HumanBytes2(total)),
- )
-}
-
-var devices = sync.OnceValue(func() []device {
+func devices() []*C.struct_ggml_backend_device {
ggml.OnceLoad()
-
- s := make([]device, C.ggml_backend_dev_count())
- for i := range s {
- s[i] = device{C.ggml_backend_dev_get(C.size_t(i))}
+ ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count())
+ for i := range ds {
+ ds[i] = C.ggml_backend_dev_get(C.size_t(i))
}
- return s
-})
+ return ds
+}
type Backend struct {
- meta *fs.GGML
- cpus, gpus []Context
- tensors map[string]*Context
+ meta *fs.GGML
+ sched *C.struct_ggml_backend_sched
+ tensors map[string]*C.struct_ggml_tensor
- sched *C.struct_ggml_backend_sched
+ // input is the backend used for inputs
+ input *C.struct_ggml_backend_buffer_type
+
+ // output is the backend used for outputs
+ output *C.struct_ggml_backend_buffer_type
+
+ // layers is the backend used for repeating layers
+ layers map[int]*C.struct_ggml_backend_buffer_type
+
+ flashAttention bool
+
+ // maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
+ maxGraphNodes int
}
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
@@ -102,106 +74,310 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
"num_key_values", len(meta.KV()),
)
- var cpus, gpus []Context
+ type deviceBufferType struct {
+ d *C.struct_ggml_backend_device
+ bts []*C.struct_ggml_backend_buffer_type
+ }
+
+ var cpus, accels, gpus []*C.struct_ggml_backend_device
for _, d := range devices() {
- switch C.ggml_backend_dev_type(d.d) {
+ switch C.ggml_backend_dev_type(d) {
+ case C.GGML_BACKEND_DEVICE_TYPE_CPU:
+ if len(cpus) == 0 {
+ // only the first cpu device should be used
+ cpus = append(cpus, d)
+ }
+ case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
+ accels = append(accels, d)
+ case C.GGML_BACKEND_DEVICE_TYPE_GPU:
+ gpus = append(gpus, d)
+ }
+ }
+
+ // create list of buffer types for the cpu
+ cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
+ for _, d := range append(accels, append(gpus, cpus...)...) {
+ switch C.ggml_backend_dev_type(d) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU,
C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
- slog.Info("cpu", "device", d)
- cpus = append(cpus, Context{
- ctx: C.ggml_init(C.struct_ggml_init_params{
- mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
- no_alloc: true,
- }),
- backend: C.ggml_backend_dev_init(d.d, nil),
- })
- case C.GGML_BACKEND_DEVICE_TYPE_GPU:
- slog.Info("gpu", "device", d)
- gpus = append(gpus, Context{
- ctx: C.ggml_init(C.struct_ggml_init_params{
- mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
- no_alloc: true,
- }),
- backend: C.ggml_backend_dev_init(d.d, nil),
- })
+ cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
}
}
- ctxFunc := func(s []Context) (*Context, error) {
- for _, e := range s {
- return &e, nil
- }
-
- return nil, fmt.Errorf("no devices available")
- }
-
- tensors := make(map[*fs.Tensor]*Context, len(meta.Tensors().Items()))
- for _, t := range meta.Tensors().Items() {
- c, err := ctxFunc(append(gpus, cpus...))
- if err != nil {
- return nil, err
- }
-
- func() {
- tt := C.ggml_new_tensor(c.ctx, t.Kind, C.int(len(t.Shape)), (*C.int64_t)(unsafe.Pointer(&t.Shape[0])))
-
- cname := C.CString(t.Name)
- defer C.free(unsafe.Pointer(cname))
- C.ggml_set_name(tt, cname)
-
- tensors[t] = c
- }()
- }
-
- for _, b := range append(gpus, cpus...) {
- C.ggml_backend_alloc_ctx_tensors(b.ctx, b.backend)
- }
-
- sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
-
- var g errgroup.Group
- for t, c := range tensors {
- g.Go(func() error {
- bts := make([]byte, t.Size())
- n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
- if err != nil {
- return err
- }
-
- if n != int(t.Size()) {
- return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
- }
-
- cname := C.CString(t.Name)
- defer C.free(unsafe.Pointer(cname))
-
- C.ggml_backend_tensor_set(C.ggml_get_tensor(c.ctx, cname), unsafe.Pointer(&bts[0]), 0, C.size_t(n))
- return nil
+ // create list of buffer types for each gpu
+ var gpuDeviceBufferTypes []deviceBufferType
+ for _, d := range gpus {
+ bt := C.ggml_backend_dev_buffer_type(d)
+ gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
+ d: d,
+ bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
})
}
- if err := g.Wait(); err != nil {
+ useDefaultSplit := true
+ for _, s := range params.TensorSplit {
+ if s != 0 {
+ useDefaultSplit = false
+ break
+ }
+ }
+
+ // calculate splits
+ splits := make([]float32, len(gpus))
+ if useDefaultSplit {
+ // default: split on free memory
+ for i := range splits {
+ var free, total C.size_t
+ C.ggml_backend_dev_memory(gpus[i], &free, &total)
+ splits[i] = float32(free)
+ }
+ } else {
+ splits = params.TensorSplit
+ }
+
+ var sum float32
+ // cumulative sum of all splits
+ for i := range splits {
+ sum += splits[i]
+ splits[i] = sum
+ }
+
+ // normalize splits
+ for i := range splits {
+ splits[i] /= sum
+ }
+
+ // inputs always use cpu
+ input := cpuDeviceBufferType
+
+ blocks := int(meta.KV().BlockCount())
+
+ // define a range of gpu layers. anything outside of this range is assigned to the cpu
+ gpuRangeStart := max(0, blocks-params.NumGPULayers)
+ gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
+ assignLayer := func(i int) deviceBufferType {
+ if i < gpuRangeStart || i >= gpuRangeStop {
+ return cpuDeviceBufferType
+ }
+
+ index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
+ if index < 0 || index >= len(gpuDeviceBufferTypes) {
+ return cpuDeviceBufferType
+ }
+
+ return gpuDeviceBufferTypes[index]
+ }
+
+ // repeating layers are assigned based on their index in reverse order, e.g. i / (block_count + 1)
+ layers := make([]deviceBufferType, blocks)
+ for i := range layers {
+ layers[i] = assignLayer(i)
+ }
+
+ // outputs are assigned iff allowed by splits and configured number of gpu layers
+ output := assignLayer(blocks)
+
+ maxTensors := len(meta.Tensors().Items())
+ maxTensors += 1
+ // each layer has at most 2 extra tensors for rope operations
+ maxTensors += blocks * 2
+
+ type tensor struct {
+ source *fs.Tensor
+ target string
+ }
+
+ // some tensors are mapped to different names so keep a list
+ targets := make(map[string][]string)
+
+ // contexts are shared by tensors of the same buffer type
+ ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
+ createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
+ for _, bt := range bts {
+ if _, ok := ctxs[bt]; !ok {
+ ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
+ mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
+ no_alloc: true,
+ })
+ }
+
+ targets[t.source.Name] = append(targets[t.source.Name], t.target)
+
+ name := t.source.Name
+ if t.target != "" {
+ name = t.target
+ }
+
+ cname := C.CString(name)
+ defer C.free(unsafe.Pointer(cname))
+ if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
+ return tt
+ }
+
+ tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
+ C.ggml_set_name(tt, cname)
+
+ slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
+ //nolint:staticcheck // TODO: check if buffer type supports this tensor
+ return tt
+ }
+
+ return nil
+ }
+
+ contains := func(s string, parts ...string) bool {
+ split := strings.Split(s, ".")
+ for _, part := range parts {
+ if slices.Contains(split, part) {
+ return true
+ }
+ }
+
+ return false
+ }
+
+ for _, t := range meta.Tensors().Items() {
+ switch {
+ case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
+ createTensor(tensor{source: t}, input.bts)
+ case contains(t.Name, "cls", "output", "output_norm"):
+ createTensor(tensor{source: t}, output.bts)
+ case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
+ // TODO: assign vision tensors to the gpu if possible
+ createTensor(tensor{source: t}, input.bts)
+ default:
+ layerIndex := -1
+ if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
+ if i, err := strconv.Atoi(fields[0]); err == nil {
+ layerIndex = i
+ }
+ }
+
+ if layerIndex >= 0 {
+ createTensor(tensor{source: t}, layers[layerIndex].bts)
+ } else {
+				// this is a repeating tensor that isn't explicitly associated with a layer so
+ // duplicate it for each layer
+ for i, layer := range layers {
+ createTensor(tensor{
+ source: t,
+ target: "blk." + strconv.Itoa(i) + "." + t.Name,
+ }, layer.bts)
+ }
+ }
+ }
+ }
+
+ // allocate buffers for each context
+ bbs := make(map[*C.struct_ggml_context]*C.struct_ggml_backend_buffer, len(ctxs))
+ for bt, c := range ctxs {
+ if C.ggml_get_first_tensor(c) == nil {
+ continue
+ }
+
+ b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
+ C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
+ bbs[c] = b
+ }
+
+ for bs := range maps.Values(bbs) {
+ slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
+ }
+
+ // map tensor names to tensors for easy lookup later
+ tensors := make(map[string]*C.struct_ggml_tensor)
+ for _, c := range ctxs {
+ for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
+ tensors[C.GoString(C.ggml_get_name(t))] = t
+ }
+ }
+
+ // concurrently read in tensor data. uses a section reader which is safe for concurrent reads
+ sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
+ var g errgroup.Group
+ for _, t := range meta.Tensors().Items() {
+ for _, target := range targets[t.Name] {
+ g.Go(func() error {
+ if target == "" {
+ target = t.Name
+ }
+
+ tt, ok := tensors[target]
+ if !ok {
+ return fmt.Errorf("unassigned tensor: %s", t.Name)
+ }
+
+ bts := make([]byte, t.Size())
+ n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
+ if err != nil {
+ return err
+ }
+
+ if n != len(bts) {
+ return errors.New("short read")
+ }
+
+ C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
+ return nil
+ })
+ }
+ }
+
+	if err := g.Wait(); err != nil {
return nil, err
}
- backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus))
- bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus))
- for i, c := range append(gpus, cpus...) {
- backends[i] = c.backend
- bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
+ // map devices to backend buffer types so new tensors can be assigned to the correct device
+ deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)
+
+ // create backends and buffer types used for the compute graph scheduler
+ var schedBackends []*C.struct_ggml_backend
+ var schedBufts []*C.struct_ggml_backend_buffer_type
+ for _, d := range append(gpus, append(accels, cpus...)...) {
+ b := C.ggml_backend_dev_init(d, nil)
+ bt := C.ggml_backend_get_default_buffer_type(b)
+ if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 {
+ // use the first gpu host buffer type for gpu if possible
+ if hbt := C.ggml_backend_dev_host_buffer_type(gpus[0]); hbt != nil {
+ bt = hbt
+ }
+ }
+
+ deviceBufferTypes[d] = bt
+
+ schedBackends = append(schedBackends, b)
+ schedBufts = append(schedBufts, bt)
+
+ slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
+
+ if C.ggml_backend_is_cpu(b) {
+ // set number of threads for cpu backend
+ C.ggml_backend_cpu_set_n_threads(b, C.int(params.NumThreads))
+ }
}
+ maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
return &Backend{
- meta: meta,
- cpus: cpus,
- gpus: gpus,
+ flashAttention: params.FlashAttention,
+ meta: meta,
+ tensors: tensors,
sched: C.ggml_backend_sched_new(
- (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
- (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
- C.int(len(backends)),
- C.size_t(max(8192, len(meta.Tensors().Items())*5)),
+ (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
+ (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
+ C.int(len(schedBackends)),
+ C.size_t(maxGraphNodes),
true,
),
+ input: deviceBufferTypes[input.d],
+ output: deviceBufferTypes[output.d],
+ layers: func() map[int]*C.struct_ggml_backend_buffer_type {
+ m := make(map[int]*C.struct_ggml_backend_buffer_type)
+ for i, layer := range layers {
+ m[i] = deviceBufferTypes[layer.d]
+ }
+ return m
+ }(),
+ maxGraphNodes: maxGraphNodes,
}, nil
}
@@ -214,51 +390,95 @@ func (b *Backend) Config() ml.Config {
}
func (b *Backend) Get(name string) ml.Tensor {
- cname := C.CString(name)
- defer C.free(unsafe.Pointer(cname))
-
- for _, c := range append(b.gpus, b.cpus...) {
- if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
- return &Tensor{t: t}
- }
+ if t, ok := b.tensors[name]; ok {
+ return &Tensor{b: b, t: t}
}
return nil
}
func (b *Backend) NewContext() ml.Context {
- nodes := max(8192, len(b.meta.Tensors().Items())*5)
- c := C.ggml_init(C.struct_ggml_init_params{
- mem_buffer: nil,
- mem_size: C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false),
- no_alloc: true,
- })
+ return b.NewContextSize(b.maxGraphNodes)
+}
- backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
- for i, c := range append(b.gpus, b.cpus...) {
- backends[i] = c.backend
+func (b *Backend) NewContextSize(n int) ml.Context {
+ if n > b.maxGraphNodes {
+ panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
}
return &Context{
- b: b,
- ctx: c,
- backend: backends[0],
- nodes: nodes,
+ b: b,
+ maxGraphNodes: n,
+ ctx: C.ggml_init(C.struct_ggml_init_params{
+ mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
+ no_alloc: true,
+ }),
+ }
+}
+
+func (b *Backend) CacheConfig() ml.CacheConfig {
+ if b.flashAttention {
+ return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
+ } else {
+ return ml.CacheConfig{CachePadding: 32, PermutedV: true}
}
}
type Context struct {
- b *Backend
- ctx *C.struct_ggml_context
- backend *C.struct_ggml_backend
+ b *Backend
+ ctx *C.struct_ggml_context
graph *C.struct_ggml_cgraph
- nodes int
+
+ // buft is the buffer type used for new tensors
+ buft *C.struct_ggml_backend_buffer_type
+
+ // maxGraphNodes is the maximum allowed number of graph nodes in this context
+ maxGraphNodes int
+}
+
+func (c Context) Input() ml.Context {
+ if c.b.input != nil {
+ return &Context{
+ b: c.b,
+ ctx: c.ctx,
+ buft: c.b.input,
+ maxGraphNodes: c.maxGraphNodes,
+ }
+ }
+
+ return &c
+}
+
+func (c Context) Output() ml.Context {
+ if c.b.output != nil {
+ return &Context{
+ b: c.b,
+ ctx: c.ctx,
+ buft: c.b.output,
+ maxGraphNodes: c.maxGraphNodes,
+ }
+ }
+
+ return &c
+}
+
+func (c Context) Layer(i int) ml.Context {
+ if buft, ok := c.b.layers[i]; ok {
+ return &Context{
+ b: c.b,
+ ctx: c.ctx,
+ buft: buft,
+ maxGraphNodes: c.maxGraphNodes,
+ }
+ }
+
+ return &c
}
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
if c.graph == nil {
- c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
+ c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
}
for _, tensor := range tensors {
@@ -268,7 +488,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
return c
}
-func (c *Context) Compute(tensors ...ml.Tensor) {
+func (c Context) Compute(tensors ...ml.Tensor) {
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
C.ggml_backend_sched_reset(c.b.sched)
@@ -287,21 +507,48 @@ func (c *Context) Compute(tensors ...ml.Tensor) {
}
}
-func (c *Context) MaxTensors() int {
- return c.nodes
+func (c Context) MaxGraphNodes() int {
+ return c.maxGraphNodes
}
func shapeToGGML(shape []int) *C.int64_t {
sh := make([]C.int64_t, len(shape))
for i, s := range shape {
- sh[i] = (C.int64_t)(s)
+ sh[i] = C.int64_t(s)
}
return &sh[0]
}
-func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
- if len(shape) < 1 || len(shape) > 4 {
+func pad(length, pad C.size_t) C.size_t {
+ return ((length + pad - 1) / pad) * pad
+}
+
+func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
+ if c.buft == nil {
+ panic("set Input, Output, or Layer before creating tensors")
+ }
+
+ var cdtype uint32
+ switch dtype {
+ case ml.DTypeF32:
+ cdtype = C.GGML_TYPE_F32
+ case ml.DTypeF16:
+ cdtype = C.GGML_TYPE_F16
+ case ml.DTypeQ80:
+ cdtype = C.GGML_TYPE_Q8_0
+ case ml.DTypeQ40:
+ cdtype = C.GGML_TYPE_Q4_0
+ case ml.DTypeI32:
+ cdtype = C.GGML_TYPE_I32
+ default:
+ panic("unsupported dtype")
+ }
+
+ if len(shape) < 1 || shape[0] == 0 {
+ var shape C.int64_t = 0
+ return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
+ } else if len(shape) > 4 {
panic("unsupported number of dimensions")
}
@@ -311,31 +558,28 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
}
}
- var t *C.struct_ggml_tensor
- switch dtype {
- case ml.DTypeF32:
- t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
- case ml.DTypeF16:
- t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
- case ml.DTypeI32:
- t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
- default:
- panic("unsupported dtype")
- }
-
- b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
+ t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
+ size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
+ b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
- C.ggml_set_zero(t)
- return &Tensor{t: t}
+ return &Tensor{b: c.b, t: t}
}
-func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
+func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
+ return c.newTensor(dtype, shape)
+}
+
+func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
+ t := c.newTensor(dtype, shape)
+ C.ggml_set_zero(t.(*Tensor).t)
+ return t
+}
+
+func checkShape[S ~[]E, E any](s S, shape ...int) error {
n := len(s)
if n == 0 {
- var shape C.int64_t = 0
- t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
- return &Tensor{t: t}, nil
+ return nil
}
for _, v := range shape {
@@ -343,22 +587,36 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
}
if n != 1 {
- return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s))
+ return fmt.Errorf("invalid shape: %v", shape)
}
- t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
- b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
- C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
- C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
- return &Tensor{t: t}, nil
+ return nil
}
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
- return fromSlice(c, s, shape, C.GGML_TYPE_F32)
+ if err := checkShape(s, shape...); err != nil {
+ return nil, err
+ }
+
+ t := c.newTensor(ml.DTypeF32, shape)
+ if len(s) > 0 {
+ C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+ }
+
+ return t, nil
}
func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
- return fromSlice(c, s, shape, C.GGML_TYPE_I32)
+ if err := checkShape(s, shape...); err != nil {
+ return nil, err
+ }
+
+ t := c.newTensor(ml.DTypeI32, shape)
+ if len(s) > 0 {
+ C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+ }
+
+ return t, nil
}
func (c *Context) Close() {
@@ -368,6 +626,7 @@ func (c *Context) Close() {
}
type Tensor struct {
+ b *Backend
t *C.struct_ggml_tensor
sync func()
}
@@ -425,6 +684,10 @@ func (t *Tensor) DType() ml.DType {
return ml.DTypeF32
case C.GGML_TYPE_F16:
return ml.DTypeF16
+ case C.GGML_TYPE_Q8_0:
+ return ml.DTypeQ80
+ case C.GGML_TYPE_Q4_0:
+ return ml.DTypeQ40
case C.GGML_TYPE_I32:
return ml.DTypeI32
default:
@@ -434,6 +697,7 @@ func (t *Tensor) DType() ml.DType {
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -448,24 +712,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
}
}
func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -475,12 +743,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
return &Tensor{
+ b: t.b,
t: mul,
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
- tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+ tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
if b != nil {
tt = tt.Add(ctx, b)
}
@@ -489,7 +758,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
}
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
- return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+ return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
}
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -498,6 +767,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
}
return &Tensor{
+ b: t.b,
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
@@ -508,18 +778,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
}
return &Tensor{
+ b: t.b,
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -528,18 +801,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
return &Tensor{
+ b: t.b,
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
}
case 2:
return &Tensor{
+ b: t.b,
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
}
case 3:
return &Tensor{
+ b: t.b,
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
}
case 4:
return &Tensor{
+ b: t.b,
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
}
default:
@@ -549,18 +826,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
}
}
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
}
}
@@ -571,6 +851,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
}
return &Tensor{
+ b: t.b,
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
@@ -579,10 +860,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
return &Tensor{
+ b: t.b,
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
}
case 3:
return &Tensor{
+ b: t.b,
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]),
C.size_t(shape[1]),
@@ -590,6 +873,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
case 5:
return &Tensor{
+ b: t.b,
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
C.size_t(shape[1]), C.size_t(shape[3]),
@@ -597,6 +881,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
case 7:
return &Tensor{
+ b: t.b,
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
@@ -613,7 +898,7 @@ const (
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
- ropeFactors = &Tensor{}
+ ropeFactors = &Tensor{b: t.b}
}
dequant := t.t
@@ -622,6 +907,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
}
return &Tensor{
+ b: t.b,
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
@@ -639,18 +925,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
+ b: t.b,
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
}
}
@@ -661,42 +950,23 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
kqMask = mask.(*Tensor).t
}
- kq := key.MulmatFullPrec(ctx, t)
- kq = &Tensor{
- t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
- }
+ query := t.Permute(ctx, 0, 2, 1, 3)
+ key = key.Permute(ctx, 0, 2, 1, 3)
- kqv := value.Mulmat(ctx, kq)
- return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-}
+ if t.b.flashAttention {
+ value = value.Permute(ctx, 0, 2, 1, 3)
-func (b *Backend) SystemInfo() string {
- var compiler string
- switch C.get_compiler() {
- case C.COMP_UNKNOWN:
- compiler = "cgo(unknown_compiler)"
- case C.COMP_GCC:
- compiler = "cgo(gcc)"
- case C.COMP_CLANG:
- compiler = "cgo(clang)"
- }
-
- var s string
- for i := range C.ggml_backend_reg_count() {
- reg := C.ggml_backend_reg_get(i)
- fName := C.CString("ggml_backend_get_features")
- defer C.free(unsafe.Pointer(fName))
- get_features_fn := C.ggml_backend_reg_get_proc_address(reg, fName)
- if get_features_fn != nil {
- s += C.GoString(C.ggml_backend_reg_name(reg))
- s += " : "
- for features := C.getBackendFeatures(get_features_fn, reg); features.name != nil; features = C.getNextBackendFeatures(features) {
- s += C.GoString(features.name)
- s += " = "
- s += C.GoString(features.value)
- s += " | "
- }
+ kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
+ C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
+ return &Tensor{b: t.b, t: kqv}
+ } else {
+ kq := key.MulmatFullPrec(ctx, query)
+ kq = &Tensor{
+ b: t.b,
+ t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
+
+ kqv := value.Mulmat(ctx, kq)
+ return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
- return s + compiler
}
diff --git a/ml/backend/ggml/ggml/include/gguf.h b/ml/backend/ggml/ggml/include/gguf.h
index 79ee20206..3efb22f01 100644
--- a/ml/backend/ggml/ggml/include/gguf.h
+++ b/ml/backend/ggml/ggml/include/gguf.h
@@ -114,6 +114,7 @@ extern "C" {
// get raw pointer to the first element of the array with the given key_id
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+ GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);
// get ith C string from array with given key_id
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
index c854e6bb2..799af5f3a 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
@@ -484,33 +484,29 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
}
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
- try {
- if (entry.is_regular_file()) {
- std::string filename = entry.path().filename().string();
- std::string ext = entry.path().extension().string();
- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- dl_handle_ptr handle { dl_load_library(entry.path()) };
- if (!handle) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
- continue;
- }
+ if (entry.is_regular_file()) {
+ std::string filename = entry.path().filename().string();
+ std::string ext = entry.path().extension().string();
+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
+ if (!handle) {
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
+ continue;
+ }
- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
- if (!score_fn) {
- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
- continue;
- }
+ auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+ if (!score_fn) {
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
+ continue;
+ }
- int s = score_fn();
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
- if (s > best_score) {
- best_score = s;
- best_path = entry.path();
- }
+ int s = score_fn();
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
+ if (s > best_score) {
+ best_score = s;
+ best_path = entry.path();
}
}
- } catch (const std::exception & e) {
- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
}
}
}
@@ -533,6 +529,14 @@ void ggml_backend_load_all() {
ggml_backend_load_all_from_path(nullptr);
}
+static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
+ try {
+ ggml_backend_load_best(name, silent, user_search_path);
+ } catch (const std::exception & e) {
+ GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
+ }
+}
+
void ggml_backend_load_all_from_path(const char * dir_path) {
#ifdef NDEBUG
bool silent = true;
@@ -540,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
bool silent = false;
#endif
- ggml_backend_load_best("blas", silent, dir_path);
- ggml_backend_load_best("cann", silent, dir_path);
- ggml_backend_load_best("cuda", silent, dir_path);
- ggml_backend_load_best("hip", silent, dir_path);
- ggml_backend_load_best("kompute", silent, dir_path);
- ggml_backend_load_best("metal", silent, dir_path);
- ggml_backend_load_best("rpc", silent, dir_path);
- ggml_backend_load_best("sycl", silent, dir_path);
- ggml_backend_load_best("vulkan", silent, dir_path);
- ggml_backend_load_best("opencl", silent, dir_path);
- ggml_backend_load_best("musa", silent, dir_path);
- ggml_backend_load_best("cpu", silent, dir_path);
+ ggml_backend_try_load_best("blas", silent, dir_path);
+ ggml_backend_try_load_best("cann", silent, dir_path);
+ ggml_backend_try_load_best("cuda", silent, dir_path);
+ ggml_backend_try_load_best("hip", silent, dir_path);
+ ggml_backend_try_load_best("kompute", silent, dir_path);
+ ggml_backend_try_load_best("metal", silent, dir_path);
+ ggml_backend_try_load_best("rpc", silent, dir_path);
+ ggml_backend_try_load_best("sycl", silent, dir_path);
+ ggml_backend_try_load_best("vulkan", silent, dir_path);
+ ggml_backend_try_load_best("opencl", silent, dir_path);
+ ggml_backend_try_load_best("musa", silent, dir_path);
+ ggml_backend_try_load_best("cpu", silent, dir_path);
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
if (backend_path) {
diff --git a/ml/backend/ggml/ggml/src/ggml.go b/ml/backend/ggml/ggml/src/ggml.go
index 85c693eba..afc1e1edd 100644
--- a/ml/backend/ggml/ggml/src/ggml.go
+++ b/ml/backend/ggml/ggml/src/ggml.go
@@ -7,6 +7,20 @@ package ggml
// #include <stdlib.h>
// #include "ggml-backend.h"
// extern void sink(int level, char *text, void *user_data);
+// static struct ggml_backend_feature * first_feature(ggml_backend_get_features_t fp, ggml_backend_reg_t reg) { return fp(reg); }
+// static struct ggml_backend_feature * next_feature(struct ggml_backend_feature * feature) { return &feature[1]; }
+/*
+typedef enum { COMPILER_CLANG, COMPILER_GNUC, COMPILER_UNKNOWN } COMPILER;
+static COMPILER compiler_name(void) {
+#if defined(__clang__)
+ return COMPILER_CLANG;
+#elif defined(__GNUC__)
+ return COMPILER_GNUC;
+#else
+ return COMPILER_UNKNOWN;
+#endif
+}
+*/
import "C"
import (
@@ -16,6 +30,7 @@ import (
"os"
"path/filepath"
"runtime"
+ "strconv"
"strings"
"sync"
"unsafe"
@@ -90,4 +105,43 @@ var OnceLoad = sync.OnceFunc(func() {
visited[abspath] = struct{}{}
}
}
+
+ slog.Info("system", "", system{})
})
+
+type system struct{}
+
+func (system) LogValue() slog.Value {
+ var attrs []slog.Attr
+ names := make(map[string]int)
+ for i := range C.ggml_backend_dev_count() {
+ r := C.ggml_backend_dev_backend_reg(C.ggml_backend_dev_get(i))
+
+ func() {
+ fName := C.CString("ggml_backend_get_features")
+ defer C.free(unsafe.Pointer(fName))
+
+ if fn := C.ggml_backend_reg_get_proc_address(r, fName); fn != nil {
+ var features []any
+ for f := C.first_feature(C.ggml_backend_get_features_t(fn), r); f.name != nil; f = C.next_feature(f) {
+ features = append(features, C.GoString(f.name), C.GoString(f.value))
+ }
+
+ name := C.GoString(C.ggml_backend_reg_name(r))
+ attrs = append(attrs, slog.Group(name+"."+strconv.Itoa(names[name]), features...))
+ names[name] += 1
+ }
+ }()
+ }
+
+ switch C.compiler_name() {
+ case C.COMPILER_CLANG:
+ attrs = append(attrs, slog.String("compiler", "cgo(clang)"))
+ case C.COMPILER_GNUC:
+ attrs = append(attrs, slog.String("compiler", "cgo(gcc)"))
+ default:
+ attrs = append(attrs, slog.String("compiler", "cgo(unknown)"))
+ }
+
+ return slog.GroupValue(attrs...)
+}
diff --git a/ml/backend/ggml/ggml/src/gguf.cpp b/ml/backend/ggml/ggml/src/gguf.cpp
index ab13669c5..f75b923fd 100644
--- a/ml/backend/ggml/ggml/src/gguf.cpp
+++ b/ml/backend/ggml/ggml/src/gguf.cpp
@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
return ctx->kv[key_id].data.data();
}
+size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ return ctx->kv[key_id].data.size();
+}
+
const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
return ctx->kv[key_id].data.data();
}
diff --git a/ml/nn/attention.go b/ml/nn/attention.go
index 4f0c9fa14..a3f43a1ea 100644
--- a/ml/nn/attention.go
+++ b/ml/nn/attention.go
@@ -3,6 +3,7 @@ package nn
import (
"fmt"
+ "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
@@ -11,40 +12,50 @@ import (
//
// Parameters:
// - ctx: Context for tensor operations
-// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
-// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
-// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
-// - mask: Optional attention mask that is added to the attention score. If
-// provided, should broadcast to [seq_len_k, seq_len_q, heads]
+// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
+// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
+// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
+// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
- if query.Dim(0) != key.Dim(0) {
- panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+ if key != nil && value != nil {
+ if query.Dim(0) != key.Dim(0) {
+ panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+ }
+
+ if key.Dim(1) != value.Dim(1) {
+ panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
+ }
+
+ if key.Dim(2) != value.Dim(2) {
+ panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+ }
+
+ if cache != nil {
+ cache.Put(ctx, key, value)
+ }
+ } else if cache == nil {
+ panic("key & value tensors must be provided if cache is nil")
}
- if mask != nil && query.Dim(1) != mask.Dim(1) {
- panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
+ var mask ml.Tensor
+ if cache != nil {
+ key, value, mask = cache.Get(ctx)
}
- if key.Dim(1) != value.Dim(0) {
- panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
- }
-
- if mask != nil && key.Dim(1) != mask.Dim(0) {
- panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
- }
-
- if key.Dim(2) != value.Dim(2) {
- panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
- }
-
- if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
+ // Only use the fast SDPA implementation if we have a cache, since that's what
+ // will do any expected backend-specific transformations for us
+ if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
} else {
+ query = query.Permute(ctx, 0, 2, 1, 3)
+ key = key.Permute(ctx, 0, 2, 1, 3)
+ value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
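
// Illustrative sketch (not part of this patch): with the new signature, a model's
// self-attention layer hands its freshly projected q/k/v plus the layer cache to
// Attention and lets it handle caching, masking and permutations. Names such as sa,
// headDim and hiddenSize are placeholders; the llama changes below show the real call site.
//
//	q := sa.Query.Forward(ctx, hiddenState).Reshape(ctx, headDim, numHeads, batchSize)
//	k := sa.Key.Forward(ctx, hiddenState).Reshape(ctx, headDim, numKVHeads, batchSize)
//	v := sa.Value.Forward(ctx, hiddenState).Reshape(ctx, headDim, numKVHeads, batchSize)
//	out := Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
//	out = out.Reshape(ctx, hiddenSize, batchSize)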
diff --git a/model/input/input.go b/model/input/input.go
new file mode 100644
index 000000000..0cb3f3f41
--- /dev/null
+++ b/model/input/input.go
@@ -0,0 +1,37 @@
+package input
+
+// Input represents one token in the input stream
+type Input struct {
+ // Token is a single element of text.
+ Token int32
+
+ // Multimodal is opaque data representing a non-text
+ // element such as an image (or part of one if the image
+	// can be processed in pieces). It may be provided either
+	// together with Token or on its own.
+ Multimodal any
+
+ // MultimodalHash is a unique representation of the data
+ // stored in Multimodal, used for caching and comparing
+ // equality.
+ MultimodalHash uint64
+}
+
+// MultimodalIndex is a multimodal element (such as an image)
+// together with an index into the slice of Inputs with the
+// corresponding token. Note that the index is not the same
+// as the position - to find the position, use this index into the
+// Positions slice.
+type MultimodalIndex struct {
+ Index int
+ Multimodal any
+}
+
+// Options contains the inputs for a model forward pass
+type Options struct {
+ Inputs []int32
+ Multimodal []MultimodalIndex
+ Positions []int32
+ Sequences []int
+ Outputs []int32
+}
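
// Illustrative sketch (not part of this patch): how a caller might populate Options
// for a single sequence whose second input carries an embedding produced by the
// model's multimodal encoder. The token ids and imgEmbedding value are hypothetical.
//
//	opts := input.Options{
//		Inputs:     []int32{9906, 1917},
//		Multimodal: []input.MultimodalIndex{{Index: 1, Multimodal: imgEmbedding}},
//		Positions:  []int32{0, 1}, // position within the sequence
//		Sequences:  []int{0, 0},   // both inputs belong to sequence 0
//		Outputs:    []int32{1},    // only return logits for the last input
//	}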
diff --git a/model/model.go b/model/model.go
index 16020b354..89b6c803b 100644
--- a/model/model.go
+++ b/model/model.go
@@ -3,7 +3,6 @@ package model
import (
"errors"
"fmt"
- "image"
_ "image/jpeg"
_ "image/png"
"log/slog"
@@ -16,23 +15,50 @@ import (
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
+ fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend"
+ "github.com/ollama/ollama/model/input"
)
-// Options contains the inputs for a model forward pass
-type Options struct {
- Inputs []int32
- Positions []int32
- Sequences []int
- Outputs []int32
+// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
+type Model interface {
+ Forward(ml.Context, input.Options) (ml.Tensor, error)
- Images []image.Image
+ Backend() ml.Backend
+ Config() config
}
-type config struct {
- Cache kvcache.Cache
+// MultimodalProcessor must be implemented by multimodal models.
+type MultimodalProcessor interface {
+ // EncodeMultimodal processes a single input (such as an image) and
+ // generates an output (typically an embedding) that can be used by the model.
+ //
+	// The return value is most typically an ml.Tensor; however, different
+	// types are possible, such as an object containing a tensor plus
+	// additional metadata, a slice of tensors, or even just the original input.
+ //
+ // The result may be cached by the runner.
+ EncodeMultimodal(ml.Context, []byte) (any, error)
+
+ // PostTokenize is called after tokenization to allow the model to edit the
+ // input stream to correctly arrange multimodal elements.
+ //
+ // The input is a slice of tokens with the results of EncodeMultimodal interleaved
+ // in the order that the user provided them. Each element of the slice will be
+ // either a single token or single multimodal object.
+ //
+ // The model must ensure that inputs are stored according to how they will be
+ // processed and stored in the cache. For example, Llava-style models should insert
+ // placeholder tokens equal to the feature size of the corresponding image with
+	// the image itself attached to and split across these tokens. When Forward is
+	// called, a partial subset of these tokens may be submitted according to the
+	// batch size.
+ //
+ // This function is also responsible for updating MultimodalHash for any Multimodal
+ // that is modified to ensure that there is a unique hash value that accurately
+ // represents the contents.
+ PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
}
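
// Illustrative sketch (not part of this patch): a Llava-style PostTokenize might expand
// each image input into featureSize placeholder inputs so that the cache layout matches
// how the image features will be consumed. featureSize, placeholderToken and the placement
// strategy here are hypothetical; a real model may instead split the embedding across the
// placeholders (as described above) and must keep MultimodalHash consistent.
//
//	func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
//		var out []input.Input
//		for _, inp := range inputs {
//			if inp.Multimodal == nil {
//				out = append(out, inp)
//				continue
//			}
//			for i := 0; i < featureSize; i++ {
//				t := input.Input{Token: placeholderToken}
//				if i == 0 {
//					t.Multimodal = inp.Multimodal
//					t.MultimodalHash = inp.MultimodalHash
//				}
//				out = append(out, t)
//			}
//		}
//		return out, nil
//	}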
// Base implements the common fields and methods for all models
@@ -41,6 +67,10 @@ type Base struct {
config
}
+type config struct {
+ Cache kvcache.Cache
+}
+
// Backend returns the underlying backend that will run the model
func (m *Base) Backend() ml.Backend {
return m.b
@@ -50,14 +80,6 @@ func (m *Base) Config() config {
return m.config
}
-// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
-type Model interface {
- Forward(ml.Context, Options) (ml.Tensor, error)
-
- Backend() ml.Backend
- Config() config
-}
-
var models = make(map[string]func(ml.Config) (Model, error))
// Register registers a model constructor for the given architecture
@@ -100,6 +122,36 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
+func NewTextProcessor(s string) (TextProcessor, error) {
+ r, err := os.Open(s)
+ if err != nil {
+ return nil, err
+ }
+ defer r.Close()
+ meta, _, err := fs.Decode(r, -1)
+ if err != nil {
+ return nil, err
+ }
+ return getTextProcessor(meta.KV())
+}
+
+func getTextProcessor(kv fs.KV) (TextProcessor, error) {
+ arch := kv.Architecture()
+ f, ok := models[arch]
+ if !ok {
+ return nil, fmt.Errorf("unsupported model architecture %q", arch)
+ }
+ m, err := f(kv)
+ if err != nil {
+ return nil, err
+ }
+ tp, ok := m.(TextProcessor)
+ if !ok {
+ return nil, fmt.Errorf("%v is not a TextProcessor", m)
+ }
+ return tp, nil
+}
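
// Illustrative sketch (not part of this patch): a caller that only needs tokenization
// can construct a TextProcessor directly from a GGUF file. The path is hypothetical;
// Encode's second argument controls whether special tokens such as BOS are added
// (see the process_text.go change below).
//
//	tp, err := model.NewTextProcessor("/path/to/model.gguf")
//	if err != nil {
//		return err
//	}
//	ids, err := tp.Encode("hello world", true)
//	if err != nil {
//		return err
//	}
//	text, err := tp.Decode(ids)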
+
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
@@ -226,7 +278,7 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice
}
-func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
if len(opts.Positions) != len(opts.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
}
@@ -237,7 +289,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
cache := m.Config().Cache
if cache != nil {
- err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
+ err := cache.StartForward(ctx, opts)
if err != nil {
return nil, err
}
diff --git a/model/model_test.go b/model/model_test.go
index 02b8aa3c2..354dd1d8b 100644
--- a/model/model_test.go
+++ b/model/model_test.go
@@ -3,12 +3,15 @@ package model
import (
"reflect"
"slices"
+ "strings"
"testing"
"github.com/google/go-cmp/cmp"
+ fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/model/input"
)
func TestParseTags(t *testing.T) {
@@ -134,3 +137,40 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
}
}
+
+func TestGetTextProcessor(t *testing.T) {
+ tp, err := getTextProcessor(fs.KV{})
+ if err == nil {
+ t.Error("expected error")
+ } else if !strings.Contains(err.Error(), "unsupported model architecture") {
+ t.Errorf("unexpected error: %v", err)
+ } else if tp != nil {
+ t.Error("expected nil tp")
+ }
+
+ models["dummy"] = func(ml.Config) (Model, error) {
+ return notTextProcessorModel{}, nil
+ }
+ tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
+ if err == nil {
+ t.Error("expected error")
+ } else if !strings.Contains(err.Error(), "not a TextProcessor") {
+ t.Errorf("unexpected error: %v", err)
+ } else if tp != nil {
+ t.Error("expected nil tp")
+ }
+}
+
+type notTextProcessorModel struct{}
+
+func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
+ panic("unimplemented")
+}
+
+func (notTextProcessorModel) Backend() ml.Backend {
+ panic("unimplemented")
+}
+
+func (notTextProcessorModel) Config() config {
+ panic("unimplemented")
+}
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 6106af867..1f27f522d 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -1,16 +1,18 @@
package llama
import (
+ "fmt"
"math"
+ "strings"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/model/input"
)
type Options struct {
- RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
@@ -29,6 +31,10 @@ type Model struct {
}
func New(c ml.Config) (model.Model, error) {
+ if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
+ return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
+ }
+
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -60,10 +66,11 @@ func New(c ml.Config) (model.Model, error) {
}
type SelfAttention struct {
- Query *nn.Linear `gguf:"attn_q"`
- Key *nn.Linear `gguf:"attn_k"`
- Value *nn.Linear `gguf:"attn_v"`
- Output *nn.Linear `gguf:"attn_output"`
+ Query *nn.Linear `gguf:"attn_q"`
+ Key *nn.Linear `gguf:"attn_k"`
+ Value *nn.Linear `gguf:"attn_v"`
+ Output *nn.Linear `gguf:"attn_output"`
+ RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
@@ -72,31 +79,24 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
- q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+ q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+ k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- cache.Put(ctx, k, v)
- k, v, mask := cache.Get(ctx)
-
- q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
- k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
- v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
- kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
+ kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
+ return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {
@@ -138,18 +138,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
- inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
+ inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
- positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
+ positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
- outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+ outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index 9b35a2628..31ba15dfd 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -1,10 +1,18 @@
package mllama
import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "hash/fnv"
+ "image"
+ "slices"
+
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/model/input"
)
type Model struct {
@@ -25,6 +33,10 @@ const (
)
func New(c ml.Config) (model.Model, error) {
+ // Verify unified config
+ if c.Uint("vision.block_count") == 0 {
+ return nil, fmt.Errorf("non-unified vision model not supported")
+ }
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -43,59 +55,99 @@ func New(c ml.Config) (model.Model, error) {
TextModel: newTextModel(c),
}
- m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
+ encoderCache := kvcache.NewEncoderCache()
+ encoderCache.SetConfig(ml.CacheConfig{})
+ m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
return &m, nil
}
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+ image, _, err := image.Decode(bytes.NewReader(multimodalData))
+ if err != nil {
+ return nil, err
+ }
+
+ f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
+ if err != nil {
+ return nil, err
+ }
+
+ pixelValues, err := ctx.Input().FromFloatSlice(f32s,
+ m.ImageProcessor.imageSize,
+ m.ImageProcessor.imageSize,
+ m.ImageProcessor.numChannels,
+ m.ImageProcessor.maxNumTiles,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
+ if err != nil {
+ return nil, err
+ }
+
+ positions := make([]int32, 1601)
+ for i := range positions {
+ positions[i] = int32(i)
+ }
+
+ positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
+ if err != nil {
+ return nil, err
+ }
+
+ crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
+ return m.Projector.Forward(ctx, crossAttentionStates), nil
+}
+
+func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+ var images []input.Input
+ fnvHash := fnv.New64a()
+
+ for i := range inputs {
+ if inputs[i].Multimodal == nil {
+ if len(images) > 0 {
+ inputs[i].Multimodal = images[0].Multimodal
+ inputs[i].MultimodalHash = images[0].MultimodalHash
+ for j := 1; j < len(images); j++ {
+ inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
+ fnvHash.Reset()
+ binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
+ binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
+ inputs[i].MultimodalHash = fnvHash.Sum64()
+ }
+ images = nil
+ }
+ } else {
+ images = append(images, inputs[i])
+ inputs[i].Token = -1
+ }
+ }
+
+ inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
+
+ return inputs, nil
+}
+
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
- if opts.Images != nil {
- f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
- if err != nil {
- return nil, err
- }
-
- pixelValues, err := ctx.FromFloatSlice(f32s,
- m.ImageProcessor.imageSize,
- m.ImageProcessor.imageSize,
- m.ImageProcessor.numChannels,
- m.ImageProcessor.maxNumTiles,
- )
- if err != nil {
- return nil, err
- }
-
- aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
- if err != nil {
- return nil, err
- }
-
- positions := make([]int32, 1601)
- for i := range positions {
- positions[i] = int32(i)
- }
-
- positionIDs, err := ctx.FromIntSlice(positions, len(positions))
- if err != nil {
- return nil, err
- }
-
- crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
- crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
+ if len(opts.Multimodal) > 0 {
+ crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
}
- inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
+ inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
- positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
+ positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
- outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+ outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 003bf9cbf..373589f9e 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -10,10 +10,11 @@ import (
)
type TextSelfAttention struct {
- Query *nn.Linear `gguf:"attn_q"`
- Key *nn.Linear `gguf:"attn_k"`
- Value *nn.Linear `gguf:"attn_v"`
- Output *nn.Linear `gguf:"attn_output"`
+ Query *nn.Linear `gguf:"attn_q"`
+ Key *nn.Linear `gguf:"attn_k"`
+ Value *nn.Linear `gguf:"attn_v"`
+ Output *nn.Linear `gguf:"attn_output"`
+ RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
@@ -22,32 +23,28 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
- query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+ query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+ key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- cache.Put(ctx, key, value)
- key, value, mask := cache.Get(ctx)
-
- query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
- key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
- value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
- attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
+ attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- // This will only get called for layers in the cache, which are just the self attention layers
- return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+ if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
+ return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+ }
+
+ return key, nil
}
type TextMLP struct {
@@ -107,7 +104,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
- var key, value, mask ml.Tensor
+ var key, value ml.Tensor
if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
@@ -119,16 +116,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
cache.Put(ctx, key, value)
- } else {
- key, value, mask = cache.Get(ctx)
}
- query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
- key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
- value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+ key, value, _ = cache.Get(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
- attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
+
+ query = query.Permute(ctx, 0, 2, 1, 3)
+ key = key.Permute(ctx, 0, 2, 1, 3)
+ value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+ kq := key.MulmatFullPrec(ctx, query)
+
+ kq = kq.Scale(ctx, scaleFactor)
+ kq = kq.Softmax(ctx)
+
+ kqv := value.Mulmat(ctx, kq)
+ attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention)
@@ -191,8 +195,6 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
}
type TextModelOptions struct {
- RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
-
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
diff --git a/model/process_text.go b/model/process_text.go
index 7083f36fd..0d75a0ed0 100644
--- a/model/process_text.go
+++ b/model/process_text.go
@@ -19,7 +19,7 @@ const (
)
type TextProcessor interface {
- Encode(string) ([]int32, error)
+ Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
}
@@ -144,7 +144,7 @@ type merge struct {
runes []rune
}
-func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
+func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
@@ -177,7 +177,6 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
- slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true)
continue
}
@@ -201,7 +200,6 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
- slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
continue
}
@@ -275,14 +273,13 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
- slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
}
}
}
}
}
- if len(ids) > 0 {
+ if addSpecial && len(ids) > 0 {
if bpe.vocab.AddBOS {
if ids[0] == bpe.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
@@ -329,6 +326,5 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
}
}
- slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}
diff --git a/model/process_text_test.go b/model/process_text_test.go
index cad1f94ff..f48303212 100644
--- a/model/process_text_test.go
+++ b/model/process_text_test.go
@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
t.Run("simple", func(t *testing.T) {
t.Parallel()
- ids, err := tokenizer.Encode("hello world")
+ ids, err := tokenizer.Encode("hello world", true)
if err != nil {
t.Error(err)
}
@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
t.Errorf("got %q, want hello world", s)
}
- ids, err = tokenizer.Encode("hello <|end_of_text|>")
+ ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
if err != nil {
t.Error(err)
}
@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
}
for s, want := range cases {
- ids, err := tokenizer.Encode(s)
+ ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Error(err)
}
@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
}
for _, want := range cases {
- ids, err := tokenizer.Encode(want)
+ ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Error(err)
}
@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
}
for s, want := range cases {
- ids, err := tokenizer.Encode(s)
+ ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Fatal(err)
}
@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for range b.N {
- _, err := tokenizer.Encode(string(bts))
+ _, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
})
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
- ids, err := tokenizer.Encode(string(bts))
+ ids, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index 82880c980..8662afc1e 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -931,7 +931,6 @@ func Execute(args []string) error {
slog.Info("starting go runner")
llama.BackendInit()
- slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads)
server := &Server{
batchSize: *batchSize,
diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go
index e1fa98b1a..a411fddb1 100644
--- a/runner/ollamarunner/cache.go
+++ b/runner/ollamarunner/cache.go
@@ -5,12 +5,12 @@ import (
"fmt"
"log/slog"
"math"
- "reflect"
"time"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/model/input"
)
type InputCache struct {
@@ -39,10 +39,7 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
slots := make([]InputCacheSlot, numSlots)
for i := range slots {
- slots[i] = InputCacheSlot{
- Id: i,
- Inputs: make([]input, 0),
- }
+ slots[i] = InputCacheSlot{Id: i}
}
cache := model.Config().Cache
@@ -62,9 +59,9 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
func kvCacheTypeFromStr(s string) ml.DType {
switch s {
case "q8_0":
- panic("kv cache quantization not yet implemented")
+ return ml.DTypeQ80
case "q4_0":
- panic("kv cache quantization not yet implemented")
+ return ml.DTypeQ40
default:
return ml.DTypeF16
}
@@ -83,7 +80,7 @@ type InputCacheSlot struct {
Id int
// Inputs that are stored in the KV cache
- Inputs []input
+ Inputs []input.Input
// is this cache actively being processed as part of a sequence?
InUse bool
@@ -92,7 +89,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
-func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -143,7 +140,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
return slot, prompt, nil
}
-func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1)
var longestSlot *InputCacheSlot
@@ -166,7 +163,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int3
return longestSlot, longest, nil
}
-func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now()
var oldestSlot *InputCacheSlot
@@ -202,7 +199,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs))
- oldestSlot.Inputs = make([]input, longest)
+ oldestSlot.Inputs = make([]input.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -212,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
return oldestSlot, longest, nil
}
-func countCommonPrefix(a []input, b []input) int32 {
+func countCommonPrefix(a []input.Input, b []input.Input) int32 {
var count int32
for i := range a {
@@ -220,7 +217,7 @@ func countCommonPrefix(a []input, b []input) int32 {
break
}
- if !reflect.DeepEqual(a[i], b[i]) {
+ if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
break
}
diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go
index 99e67b4fa..0a1b73f5a 100644
--- a/runner/ollamarunner/cache_test.go
+++ b/runner/ollamarunner/cache_test.go
@@ -4,6 +4,8 @@ import (
"image"
"testing"
"time"
+
+ "github.com/ollama/ollama/model/input"
)
func TestCountCommon(t *testing.T) {
@@ -13,44 +15,50 @@ func TestCountCommon(t *testing.T) {
tests := []struct {
name string
- t1 []input
- t2 []input
+ t1 []input.Input
+ t2 []input.Input
expected int32
}{
{
name: "Equal",
- t1: []input{{token: 1}, {token: 2}, {token: 3}},
- t2: []input{{token: 1}, {token: 2}, {token: 3}},
+ t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+ t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3,
},
{
name: "Prefix",
- t1: []input{{token: 1}},
- t2: []input{{token: 1}, {token: 2}, {token: 3}},
+ t1: []input.Input{{Token: 1}},
+ t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1,
},
{
name: "Image Prefix",
- t1: []input{{image: imgA}},
- t2: []input{{image: imgA}, {image: imgB}, {image: imgC}},
+ t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
+ t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
expected: 1,
},
{
name: "Mixed",
- t1: []input{{token: 1}, {image: imgA}},
- t2: []input{{token: 1}, {image: imgA}, {token: 5}},
+ t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+ t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
expected: 2,
},
+ {
+ name: "Mixed, Same Length",
+ t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+ t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
+ expected: 1,
+ },
{
name: "Empty",
- t1: []input{},
- t2: []input{{token: 1}, {token: 2}, {token: 3}},
+ t1: []input.Input{},
+ t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0,
},
{
name: "Both Empty",
- t1: []input{},
- t2: []input{},
+ t1: []input.Input{},
+ t2: []input.Input{},
expected: 0,
},
}
@@ -74,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
- prompt []input
+ prompt []input.Input
longest expected
best expected
}{
@@ -83,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
- Inputs: []input{},
+ Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
{
Id: 1,
- Inputs: []input{},
+ Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
- prompt: []input{{token: 1}},
+ prompt: []input.Input{{Token: 1}},
longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0},
},
@@ -103,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
- Inputs: []input{{token: 1}},
+ Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
- Inputs: []input{{token: 1}, {token: 2}},
+ Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
- prompt: []input{{token: 1}, {token: 2}},
+ prompt: []input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2},
},
@@ -123,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
- Inputs: []input{{token: 1}, {token: 2}},
+ Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
- Inputs: []input{},
+ Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
- prompt: []input{{token: 2}},
+ prompt: []input.Input{{Token: 2}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -144,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
- Inputs: []input{{token: 1}, {token: 2}},
+ Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
- Inputs: []input{},
+ Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
},
},
- prompt: []input{{token: 1}},
+ prompt: []input.Input{{Token: 1}},
longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1},
},
@@ -165,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
- Inputs: []input{{token: 1}},
+ Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
- Inputs: []input{{token: 1}, {token: 2}},
+ Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
- prompt: []input{{token: 2}, {token: 3}},
+ prompt: []input.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -185,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
- Inputs: []input{{token: 1}, {token: 2}},
+ Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
- Inputs: []input{{token: 1}},
+ Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
- prompt: []input{{token: 1}, {token: 2}},
+ prompt: []input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2},
},
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index db9b271e5..c1475cbb2 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -1,13 +1,12 @@
package ollamarunner
import (
- "bytes"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
- "image"
+ "hash/maphash"
"log"
"log/slog"
"net"
@@ -27,28 +26,26 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
_ "github.com/ollama/ollama/model/models"
)
-// input is an element of the prompt to process, either a token or an image
-type input struct {
- token int32
-
- image image.Image
-}
-
type Sequence struct {
+ // ctx for allocating tensors that last the lifetime of the sequence, such as
+ // multimodal embeddings
+ ctx ml.Context
+
// batch index
iBatch int
// prompt inputs left to evaluate
- inputs []input
+ inputs []input.Input
// inputs that have been added to a batch but not yet submitted to Forward
- pendingInputs []input
+ pendingInputs []input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
@@ -101,8 +98,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
s.ready.Wait()
startTime := time.Now()
+ ctx := s.model.Backend().NewContext()
- inputs, err := s.inputs(prompt, images)
+ inputs, err := s.inputs(ctx, prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 {
@@ -128,6 +126,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// TODO(jessegross): Ingest cached history for grammar
return &Sequence{
+ ctx: ctx,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
@@ -146,28 +145,31 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-] tags, tokenizing text and
// decoding images
-func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
- var inputs []input
+func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
+ var inputs []input.Input
var parts []string
var matches [][]string
- // TODO(jessegross): This can sometimes trigger for matching text in the
- // user's prompt. We previously tried to avoid it by only looking for images
- // on image models. We don't have a clear indication now but it would be better
- // to properly escape it in any case.
- re := regexp.MustCompile(`\[img-(\d+)\]`)
- parts = re.Split(prompt, -1)
- matches = re.FindAllStringSubmatch(prompt, -1)
+ multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)
+ if visionModel {
+ re := regexp.MustCompile(`\[img-(\d+)\]`)
+ parts = re.Split(prompt, -1)
+ matches = re.FindAllStringSubmatch(prompt, -1)
+ } else {
+ parts = []string{prompt}
+ }
+
+ postTokenize := false
for i, part := range parts {
// text - tokenize
- tokens, err := s.model.(model.TextProcessor).Encode(part)
+ tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
return nil, err
}
for _, t := range tokens {
- inputs = append(inputs, input{token: t})
+ inputs = append(inputs, input.Input{Token: t})
}
// image - decode and store
@@ -186,12 +188,25 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
return nil, fmt.Errorf("invalid image index: %d", n)
}
- image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data))
+ imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil {
return nil, err
}
- inputs = append(inputs, input{image: image})
+ s.multimodalHash.Reset()
+ _, _ = s.multimodalHash.Write(images[imageIndex].Data)
+ imageHash := s.multimodalHash.Sum64()
+
+ inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
+ postTokenize = true
+ }
+ }
+
+ if visionModel && postTokenize {
+ var err error
+ inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
+ if err != nil {
+ return nil, err
}
}
@@ -236,8 +251,15 @@ type Server struct {
// KV cache
cache *InputCache
- // next sequence for prompt processing to avoid starvation
- nextSeq int
+ // multimodalHash generates hashes for comparing equality
+ // of non-text data
+ multimodalHash maphash.Hash
+
+	// vocab is a llama.cpp vocab required for grammar-based
+ // constrained generation (json mode, structured outputs)
+ // TODO: this is temporary until Ollama sampling supports
+ // constrained generation
+ vocab *sample.Vocab
}
func (s *Server) allNil() bool {
@@ -283,6 +305,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
+ seq.ctx.Close()
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
@@ -310,30 +333,25 @@ func (s *Server) processBatch() error {
}
defer s.mu.Unlock()
- var options model.Options
- imgSeq := -1
-
- seqIdx := s.nextSeq - 1
- for range s.seqs {
- seqIdx = (seqIdx + 1) % len(s.seqs)
- seq := s.seqs[seqIdx]
+ var options input.Options
+ for i, seq := range s.seqs {
if seq == nil {
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
- s.removeSequence(seqIdx, "limit")
+ s.removeSequence(i, "limit")
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
- seq.cache.Inputs = []input{}
+ seq.cache.Inputs = []input.Input{}
}
- for i, input := range seq.inputs {
+ for j, inp := range seq.inputs {
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
@@ -345,37 +363,23 @@ func (s *Server) processBatch() error {
}
}
- if i >= s.batchSize {
+ if j >= s.batchSize {
break
}
- // TODO(jessegross): Image inputs need to be rethought - it's
- // it doesn't work well for different types of models or multiple sequences
- if input.image != nil {
- if len(seq.pendingInputs) != len(options.Images) {
- break
- }
-
- if imgSeq != seqIdx && imgSeq != -1 {
- s.nextSeq = seqIdx
- break
- }
-
- imgSeq = seqIdx
- options.Images = append(options.Images, input.image)
- seq.pendingInputs = append(seq.pendingInputs, input)
- continue
+ options.Inputs = append(options.Inputs, inp.Token)
+ if inp.Multimodal != nil {
+ options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
}
- options.Inputs = append(options.Inputs, input.token)
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id)
seq.iBatch = len(options.Outputs)
- if i+1 == len(seq.inputs) {
+ if j+1 == len(seq.inputs) {
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
}
- seq.pendingInputs = append(seq.pendingInputs, input)
+ seq.pendingInputs = append(seq.pendingInputs, inp)
}
seq.inputs = seq.inputs[len(seq.pendingInputs):]
@@ -403,7 +407,7 @@ func (s *Server) processBatch() error {
// After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
- seq.pendingInputs = []input{}
+ seq.pendingInputs = []input.Input{}
}
// don't sample prompt processing
@@ -422,6 +426,7 @@ func (s *Server) processBatch() error {
// if done processing the prompt, generate an embedding and return
if seq.embeddingOnly {
// TODO(jessegross): Embedding support
+ slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, "")
continue
}
@@ -449,7 +454,7 @@ func (s *Server) processBatch() error {
return err
}
- seq.inputs = []input{{token: token}}
+ seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
@@ -575,11 +580,30 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
+ var grammar *sample.Grammar
+ var err error
+ if req.Grammar != "" {
+ grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
+ if err != nil {
+ http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
+ return
+ }
+ }
+
+ sampler := sample.NewSampler(
+ req.Temperature,
+ req.TopK,
+ req.TopP,
+ req.MinP,
+ req.Seed,
+ grammar,
+ )
+
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: int32(req.NumKeep),
- sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized
+ sampler: sampler,
embedding: false,
})
if err != nil {
@@ -786,7 +810,7 @@ func (s *Server) loadModel(
panic(err)
}
- slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads)
+ s.vocab = sample.NewVocab(mpath)
// TODO(jessegross): LoRA loading
if lpath.String() != "" {
@@ -818,7 +842,7 @@ func Execute(args []string) error {
batchSize := fs.Int("batch-size", 512, "Batch size")
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
- _ = fs.Bool("flash-attn", false, "Enable flash attention")
+ flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on")
@@ -863,7 +887,6 @@ func Execute(args []string) error {
}
// TODO(jessegross): Parameters that need to be implemented:
- // flash-attn
// no-mmap
// mlock
@@ -878,10 +901,11 @@ func Execute(args []string) error {
}
params := ml.BackendParams{
- NumThreads: *threads,
- NumGPULayers: *numGPULayers,
- MainGPU: *mainGPU,
- TensorSplit: tensorSplitFloats,
+ NumThreads: *threads,
+ NumGPULayers: *numGPULayers,
+ MainGPU: *mainGPU,
+ TensorSplit: tensorSplitFloats,
+ FlashAttention: *flashAttention,
}
server.ready.Add(1)
diff --git a/sample/samplers.go b/sample/samplers.go
index 1b8a5edd9..aea99b3f2 100644
--- a/sample/samplers.go
+++ b/sample/samplers.go
@@ -3,118 +3,226 @@ package sample
import (
"errors"
"math"
+ "math/rand/v2"
+ "slices"
+ "sync"
- "golang.org/x/exp/rand"
- "gonum.org/v1/gonum/stat/sampleuv"
+ "github.com/ollama/ollama/llama"
)
-type Sampler interface {
- Sample([]float32) (int32, error)
+// token represents information about a single token during sampling
+type token struct {
+ id int32 // The token's unique identifier
+ value float32 // The raw logit or probability from the model
}
-type weighted struct {
- src rand.Source
- transforms []Transform
+type Sampler struct {
+ rng *rand.Rand
+ topK int
+ topP float32
+ minP float32
+ temperature float32
+ grammar *Grammar
}
-// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
-func Weighted(seed *uint64, transforms ...Transform) Sampler {
- var src rand.Source
- if seed != nil {
- src = rand.NewSource(*seed)
- }
- return weighted{src: src, transforms: transforms}
-}
-
-func (s weighted) Sample(logits []float32) (int32, error) {
- logits64 := make([]float64, len(logits))
- for i, v := range logits {
- logits64[i] = float64(v)
- }
-
- for _, t := range s.transforms {
- logits64 = t.Apply(logits64)
- }
-
- logitsCopy := make([]float64, 0, len(logits))
- indices := make([]int, 0, len(logits))
- for i, logit := range logits64 {
- if !math.IsInf(logit, -1) {
- logitsCopy = append(logitsCopy, logit)
- indices = append(indices, i)
- }
- }
-
- if len(logitsCopy) == 0 {
- return -1, errors.New("no valid logits found for weighed sampling")
- }
-
- probs := softmax(logitsCopy)
- w := sampleuv.NewWeighted(probs, s.src)
- if idx, ok := w.Take(); ok {
- return int32(indices[idx]), nil
- }
- return -1, errors.New("weighted sampler failed, no valid token found")
-}
-
-type greedy struct{}
-
-func Greedy() Sampler {
- return greedy{}
-}
-
-// Sample returns the index of the maximum value in logits.
-func (s greedy) Sample(logits []float32) (int32, error) {
- if len(logits) == 0 {
- return -1, errors.New("no logits provided for greedy sampling")
- }
-
- maxIdx := 0
+func (s *Sampler) Sample(logits []float32) (int32, error) {
+ tokens := make([]token, len(logits))
for i := range logits {
- if logits[i] > logits[maxIdx] {
- maxIdx = i
+ tokens[i].id = int32(i)
+ tokens[i].value = logits[i]
+ }
+
+ t, err := s.sample(tokens)
+ if err != nil {
+ return -1, err
+ }
+
+ if s.grammar != nil {
+		// optimization: first check if the sampled token is accepted by the grammar;
+		// if it is rejected, apply the grammar to all logits and sample again (slower)
+ top := []token{t}
+ s.grammar.Apply(top)
+ if !math.IsInf(float64(top[0].value), -1) {
+ s.grammar.Accept(top[0].id)
+ return top[0].id, nil
+ }
+
+		// sample modifies the tokens slice, so reset it before applying
+		// the grammar and sampling again
+ for i := range logits {
+ tokens[i].id = int32(i)
+ tokens[i].value = logits[i]
+ }
+ s.grammar.Apply(tokens)
+ t, err = s.sample(tokens)
+ if err != nil {
+ return -1, err
+ }
+ s.grammar.Accept(t.id)
+ }
+
+ return t.id, nil
+}
+
+// greedy returns the highest probability token from the tokens
+func greedy(tokens []token) token {
+ max := tokens[0]
+ for i := 1; i < len(tokens); i++ {
+ if tokens[i].value > max.value {
+ max = tokens[i]
}
}
- return int32(maxIdx), nil
+ return max
+}
+
+// sample selects a token from tokens according to the sampler parameters.
+// It has the side effect of modifying the tokens slice.
+func (s *Sampler) sample(tokens []token) (token, error) {
+ if s.temperature == 0 {
+ return greedy(tokens), nil
+ }
+
+ if s.topK > 0 {
+ tokens = topK(tokens, s.topK)
+ } else {
+ sortLogits(tokens)
+ }
+
+ // token logit values are updated to probabilities
+ tokens = temperature(tokens, s.temperature)
+
+ tokens = topP(tokens, s.topP)
+ tokens = minP(tokens, s.minP)
+
+	// TODO: this should fall back to greedy sampling, or the topP/topK/minP
+	// values should be constrained so that there are always tokens to sample from
+ if len(tokens) == 0 {
+ return token{}, errors.New("no tokens to sample from")
+ }
+
+ var r float32
+ if s.rng != nil {
+ r = s.rng.Float32()
+ } else {
+ r = rand.Float32()
+ }
+
+ // Calculate cumulative sum of probabilities
+ var sum float32
+ for i := range tokens {
+ sum += tokens[i].value
+ tokens[i].value = sum
+ }
+ r *= tokens[len(tokens)-1].value
+
+ idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
+ if token.value < target {
+ return -1
+ }
+ return 1
+ })
+
+ return tokens[idx], nil
}
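
// Worked example (illustrative, not part of this patch): if three tokens survive the
// filters with probabilities 0.5, 0.3 and 0.2, the cumulative values become 0.5, 0.8
// and 1.0. A draw of r = 0.6 binary-searches to the first cumulative value >= 0.6,
// which is 0.8 at index 1, so the second token is returned. Each token is therefore
// chosen with probability proportional to its (filtered) probability mass.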
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
-func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
- if temperature == 0 {
- return Greedy(), nil
+func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
+ var rng *rand.Rand
+ if seed != -1 {
+ // PCG requires two parameters: sequence and stream
+ // Use original seed for sequence
+ sequence := uint64(seed)
+ // Use golden ratio hash to generate statistically independent seeds
+ rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
+ }
+ if temperature < 0.0 {
+ temperature = 0.0
}
- if temperature < 0 || temperature > 2 {
- return nil, errors.New("temperature must be between 0 and 2")
+ if topP < 0.0 {
+ topP = 0.0
+ }
+ if topP >= 1.0 {
+ topP = 1.0
}
- transforms := []Transform{Temperature(temperature)}
-
- if topK != 0 {
- if topK <= 0 {
- return nil, errors.New("topK must be greater than 0")
- }
- transforms = append(transforms, TopK(topK))
+ if minP < 0.0 {
+ minP = 0.0
+ }
+ if minP >= 1.0 {
+ minP = 1.0
}
- if topP != 0 {
- if topP < 0 || topP >= 1 {
- return nil, errors.New("topP must be between 0 and 1")
- }
- transforms = append(transforms, TopP(topP))
+ return Sampler{
+ rng: rng,
+ topK: topK,
+ topP: topP,
+ minP: minP,
+ temperature: temperature,
+ grammar: grammar,
}
-
- if minP != 0 {
- if minP < 0 || minP >= 1 {
- return nil, errors.New("minP must be between 0 and 1")
- }
- transforms = append(transforms, MinP(minP))
- }
-
- if seed >= 0 {
- seed64 := uint64(seed)
- return Weighted(&seed64, transforms...), nil
- }
- return Weighted(nil, transforms...), nil
+}
+
+type Grammar struct {
+ vocab *Vocab
+ grammar string
+ sampler *llama.Sampler
+}
+
+func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
+ v, err := vocab.Load()
+ if err != nil {
+ return nil, err
+ }
+
+ return &Grammar{
+ vocab: vocab,
+ grammar: grammar,
+ sampler: llama.NewGrammarSampler(v, grammar),
+ }, nil
+}
+
+func (g *Grammar) Apply(tokens []token) {
+ tds := make([]llama.TokenData, len(tokens))
+ for i, token := range tokens {
+ tds[i].Id = token.id
+ tds[i].Logit = token.value
+ }
+
+ g.sampler.Apply(tds)
+
+ for i := range tokens {
+ tokens[i].value = tds[i].Logit
+ }
+}
+
+func (g *Grammar) Accept(token int32) {
+ g.sampler.Accept(token)
+}
+
+type Vocab struct {
+ once sync.Once
+ vocab *llama.Vocab
+ err error
+ path string
+}
+
+func NewVocab(path string) *Vocab {
+ return &Vocab{path: path}
+}
+
+// Load returns the lazily-loaded vocabulary
+func (v *Vocab) Load() (*llama.Vocab, error) {
+ v.once.Do(func() {
+ vocab, err := llama.LoadVocabFromFile(v.path)
+ if err != nil {
+ v.err = err
+ return
+ }
+ v.vocab = vocab
+ })
+ return v.vocab, v.err
}
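
// Illustrative sketch (not part of this patch): wiring the lazily-loaded Vocab into a
// grammar-constrained sampler, roughly mirroring how the runner uses it. The model
// path and GBNF grammar string are hypothetical.
//
//	vocab := NewVocab("/path/to/model.gguf")
//	grammar, err := NewGrammar(vocab, `root ::= "yes" | "no"`)
//	if err != nil {
//		return err
//	}
//	sampler := NewSampler(0.8, 40, 0.9, 0.05, -1, grammar)
//	id, err := sampler.Sample(logits)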
diff --git a/sample/samplers_benchmark_test.go b/sample/samplers_benchmark_test.go
new file mode 100644
index 000000000..cd1380141
--- /dev/null
+++ b/sample/samplers_benchmark_test.go
@@ -0,0 +1,92 @@
+package sample
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+)
+
+func BenchmarkWeightedSampler(b *testing.B) {
+ sizes := []int{10, 100, 1000, 10000}
+
+ for _, size := range sizes {
+ b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
+ logits := make([]float32, size)
+ for i := range logits {
+ logits[i] = float32(rand.Float64()*10 - 5)
+ }
+
+ sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
+ b.ResetTimer()
+ for b.Loop() {
+ sampler.Sample(logits)
+ }
+ })
+ }
+
+ configs := []struct {
+ name string
+ temperature float32
+ topK int
+ topP float32
+ minP float32
+ seed int
+ }{
+ {"Greedy", 0, -1, 0, 0, -1},
+ {"Temperature", 0.8, -1, 0, 0, -1},
+ {"TopK", 0.8, 50, 0, 0, -1},
+ {"TopP", 0.8, -1, 0.9, 0, -1},
+ {"MinP", 0.8, -1, 0, 0.05, -1},
+ {"WithSeed", 0.8, 50, 0, 0, 42},
+ }
+
+ // Fixed size for common vocab size
+ size := 128000
+ logits := make([]float32, size)
+ for i := range logits {
+ logits[i] = float32(rand.Float64()*10 - 5)
+ }
+
+ for _, tc := range configs {
+ b.Run("Config"+tc.name, func(b *testing.B) {
+ sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
+ sampler.Sample(logits)
+
+ b.ResetTimer()
+
+ for b.Loop() {
+ sampler.Sample(logits)
+ }
+ })
+ }
+
+ // Test with combined transforms separately - topK influences performance greatly
+ b.Run("TransformCombined", func(b *testing.B) {
+ sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
+ b.ResetTimer()
+
+ for b.Loop() {
+ sampler.Sample(logits)
+ }
+ })
+}
+
+func BenchmarkGreedySampler(b *testing.B) {
+ sizes := []int{10, 100, 1000, 10000, 100000}
+
+ for _, size := range sizes {
+ b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
+ logits := make([]float32, size)
+ for i := range logits {
+ logits[i] = float32(rand.Float64()*10 - 5)
+ }
+
+ sampler := NewSampler(0, -1, 0, 0, -1, nil)
+ b.ResetTimer()
+
+ for b.Loop() {
+ sampler.Sample(logits)
+ }
+ })
+ }
+}
diff --git a/sample/samplers_test.go b/sample/samplers_test.go
index 32364a3b7..38b9b352a 100644
--- a/sample/samplers_test.go
+++ b/sample/samplers_test.go
@@ -1,15 +1,14 @@
package sample
import (
- "math"
"math/rand/v2"
"testing"
-
- "github.com/google/go-cmp/cmp"
)
func TestWeighted(t *testing.T) {
- got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
+ logits := []float32{-10, 3, -10, -10}
+ sampler := NewSampler(0, 0, 0, 0, 0, nil)
+ got, err := sampler.Sample(logits)
if err != nil {
t.Error(err)
return
@@ -19,194 +18,26 @@ func TestWeighted(t *testing.T) {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
- got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
- if err == nil {
- t.Error("expected error for no valid tokens, got index", got)
- }
-
- seed := uint64(42)
- got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
+ logits = []float32{-100, -10, 0, 10}
+ sampler = NewSampler(0, 0, 0, 0, 0, nil)
+ got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
return
}
- // With seed 42, we expect a consistent sample
- want = int32(3) // This will be deterministic due to the seed
+ want = int32(3) // Should pick highest probability with this r value
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
}
-type testTransform struct {
- id int
- callOrder *[]int
-}
-
-func (ts *testTransform) Apply(logits []float64) []float64 {
- if ts.callOrder != nil {
- *ts.callOrder = append(*ts.callOrder, ts.id)
- }
- return logits
-}
-
-func TestSample(t *testing.T) {
- input := []float32{1, 2, 3, 4}
-
- var callOrder []int
- mock1 := &testTransform{
- id: 1,
- callOrder: &callOrder,
- }
- mock2 := &testTransform{
- id: 2,
- callOrder: &callOrder,
- }
- mock3 := &testTransform{
- id: 3,
- callOrder: &callOrder,
- }
-
- _, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
- if err != nil {
- t.Error(err)
- return
- }
- wantOrder := []int{1, 2, 3}
- if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
- t.Errorf("call order mismatch (-want +got):\n%s", diff)
- }
-}
-
-func TestNewSampler(t *testing.T) {
- tests := []struct {
- name string
- temperature float32
- topK int
- topP float32
- minP float32
- seed int
- wantErr bool
- }{
- {
- name: "no transforms",
- // temperature is 0, so greedy should be used
- wantErr: false,
- },
- {
- name: "temperature",
- temperature: 0.5,
- wantErr: false,
- },
- {
- name: "invalid temperature negative",
- temperature: -1,
- wantErr: true,
- },
- {
- name: "invalid temperature too high",
- temperature: 2.1,
- wantErr: true,
- },
- {
- name: "top k",
- topK: 10,
- temperature: 0.8,
- wantErr: false,
- },
- {
- name: "invalid top k negative",
- topK: -1,
- temperature: 0.8,
- wantErr: true,
- },
- {
- name: "top p",
- topP: 0.9,
- temperature: 0.8,
- wantErr: false,
- },
- {
- name: "invalid top p negative",
- topP: -0.1,
- temperature: 0.8,
- wantErr: true,
- },
- {
- name: "invalid top p one",
- topP: 1.0,
- temperature: 0.8,
- wantErr: true,
- },
- {
- name: "min p",
- minP: 0.2,
- temperature: 0.8,
- wantErr: false,
- },
- {
- name: "invalid min p negative",
- minP: -0.1,
- temperature: 0.8,
- wantErr: true,
- },
- {
- name: "invalid min p one",
- minP: 1.0,
- temperature: 0.8,
- wantErr: true,
- },
- {
- name: "default values",
- temperature: 0.8,
- topK: 40,
- topP: 0.9,
- minP: 0.0,
- seed: 0,
- wantErr: false,
- },
- {
- name: "all zeroes",
- temperature: 0.0,
- topK: 0,
- topP: 0.0,
- minP: 0.0,
- seed: 0,
- wantErr: false, // all zeroes means no transforms
- },
- {
- name: "all transforms",
- temperature: 0.8,
- topK: 50,
- topP: 0.95,
- minP: 0.1,
- seed: 42,
- wantErr: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- _, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
- if (err != nil) != tt.wantErr {
- t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
-
func BenchmarkSample(b *testing.B) {
- transforms := []Transform{
- Temperature(0.5),
- TopK(10),
- TopP(0.9),
- MinP(0.2),
- }
-
samplers := map[string]Sampler{
- "Greedy": Greedy(),
- "Weighted": Weighted(nil, transforms...),
+ "Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
+ "Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
}
+ // Generate random logits for benchmarking
logits := make([]float32, 1<<16)
for i := range logits {
logits[i] = rand.Float32()
@@ -215,9 +46,9 @@ func BenchmarkSample(b *testing.B) {
for name, s := range samplers {
b.Run(name, func(b *testing.B) {
b.ResetTimer()
- for range b.N {
+ for b.Loop() {
if _, err := s.Sample(logits); err != nil {
- b.Error(err)
+ b.Fatalf("error sampling: %v", err)
}
}
})
diff --git a/sample/transforms.go b/sample/transforms.go
index 2dc6ebae1..ab62455f3 100644
--- a/sample/transforms.go
+++ b/sample/transforms.go
@@ -1,120 +1,195 @@
package sample
import (
- "cmp"
"math"
"slices"
-
- pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
)
-type Transform interface {
- Apply([]float64) []float64
-}
-
-// TODO(parthsareen): potentially cache softmax values
-func softmax(logits []float64) []float64 {
- var sum float64
- probs := make([]float64, len(logits))
- for i, v := range logits {
- probs[i] = math.Exp(v)
- sum += probs[i]
- }
-
- for i := range probs {
- probs[i] /= sum
- }
-
- return probs
-}
-
-type Temperature float64
-
-func (t Temperature) Apply(logits []float64) []float64 {
- temp := math.Max(float64(t), 1e-7)
-
- // subtracting max logit to avoid under/overflow
- maxLogit := slices.Max(logits)
- for i := range logits {
- logits[i] = (logits[i] - maxLogit) / temp
- }
-
- return logits
-}
-
-type logitMap struct {
- index int
- logit float64
-}
-
-type TopK int
-
-// TODO(parthsareen): avoid having to check all logits after this transform
-func (k TopK) Apply(logits []float64) []float64 {
- if int(k) >= len(logits) {
- return logits
- }
- q := pq.NewWith(func(a, b logitMap) int {
- return -cmp.Compare(a.logit, b.logit)
- })
-
- for i, logit := range logits {
- q.Enqueue(logitMap{index: i, logit: logit})
- }
-
- validLogits := make(map[int]float64)
- for range k {
- logitMap, _ := q.Dequeue()
- validLogits[logitMap.index] = logitMap.logit
- }
-
- for i := range logits {
- if _, ok := validLogits[i]; !ok {
- logits[i] = math.Inf(-1)
+// temperature applies scaling and softmax to the logits
+func temperature(ts []token, temp float32) []token {
+ // Find max logit for numerical stability
+ maxLogit := float32(math.Inf(-1))
+ for _, t := range ts {
+ if t.value > maxLogit {
+ maxLogit = t.value
}
}
- return logits
-}
-
-type TopP float64
-
-func (p TopP) Apply(logits []float64) []float64 {
- probs := softmax(logits)
- indices := make([]int, len(probs))
- for i := range indices {
- indices[i] = i
+ // Apply temperature and compute exp(x - max)
+ temp = max(temp, 1e-7)
+ var sum float32
+ for i, v := range ts {
+ ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
+ sum += ts[i].value
}
- // sort in descending order
- slices.SortFunc(indices, func(i, j int) int {
- return cmp.Compare(probs[j], probs[i])
- })
+ // Normalize
+ for i := range ts {
+ ts[i].value /= sum
+ }
- var sum float64
- for i, idx := range indices {
- sum += probs[idx]
- if sum > float64(p) {
- for _, idx := range indices[i+1:] {
- logits[idx] = math.Inf(-1)
- }
+ return ts
+}
+
+// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
+//
+// The heap is represented as an array where for any node at index i:
+// - Left child is at index 2i + 1
+// - Right child is at index 2i + 2
+// - Parent is at index (i-1)/2
+//
+// The function compares a node with its children and:
+// 1. Finds the smallest value between the node and its children
+// 2. If the node is not the smallest, swaps it with its smallest child
+// 3. Continues this process down the affected path until the min-heap property is restored
+func siftDown(data []token, start, end int) {
+ root := start
+ for {
+ child := 2*root + 1
+ if child >= end {
break
}
+ // Find smaller child (we want min heap)
+ if child+1 < end && data[child+1].value < data[child].value {
+ child++
+ }
+ // Exit if root is already smaller than children
+ if data[root].value <= data[child].value {
+ break
+ }
+ // Swap with smaller child and continue
+ data[root], data[child] = data[child], data[root]
+ root = child
}
- return logits
}
-type MinP float64
+// topK limits the number of tokens considered to the k highest logits
+func topK(ts []token, k int) []token {
+ if k >= len(ts) {
+ return ts
+ }
+ // Heapify + siftDown - O(nlog(k))
+ // Build min-heap of first k elements
+ heap := ts[:k]
+ for i := k/2 - 1; i >= 0; i-- {
+ siftDown(heap, i, k)
+ }
-func (p MinP) Apply(logits []float64) []float64 {
- probs := softmax(logits)
- threshold := slices.Max(probs) * float64(p)
-
- for i, prob := range probs {
- if prob < threshold {
- logits[i] = math.Inf(-1)
+ // Process remaining elements - if larger than heap root, replace root
+ for i := k; i < len(ts); i++ {
+ if ts[i].value > heap[0].value {
+ heap[0] = ts[i]
+ siftDown(heap, 0, k)
}
}
- return logits
+ slices.Reverse(heap)
+
+ ts = heap
+ return ts
+}
+
+// topP keeps the highest-probability tokens until their cumulative probability exceeds p
+func topP(ts []token, p float32) []token {
+ if p == 1.0 {
+ return ts
+ }
+
+ // Find cutoff index where cumulative sum exceeds p
+ var sum float32
+ for i, t := range ts {
+ sum += t.value
+ if sum > float32(p) {
+ ts = ts[:i+1]
+ return ts
+ }
+ }
+
+ return ts
+}
+
+// minP discards tokens whose probability is below p times the highest token probability
+func minP(ts []token, p float32) []token {
+ if p == 1.0 {
+ return ts
+ }
+
+ maxProb := float32(math.Inf(-1))
+ for _, token := range ts {
+ if token.value > maxProb {
+ maxProb = token.value
+ }
+ }
+
+ threshold := maxProb * float32(p)
+
+ // Filter tokens in-place
+ validTokens := ts[:0]
+ for i, token := range ts {
+ if token.value >= threshold {
+ validTokens = append(validTokens, ts[i])
+ }
+ }
+
+ ts = validTokens
+ return ts
+}
+
+// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
+// sortLogits sorts tokens by logit in descending order using counting sort;
+// counting sort is faster than the built-in sort for this use case
+func sortLogits(tokens []token) {
+ if len(tokens) <= 1 {
+ return
+ }
+
+ // Find max/min in a single pass
+ minLogit, maxLogit := tokens[0].value, tokens[0].value
+ for _, t := range tokens[1:] {
+ if t.value < minLogit {
+ minLogit = t.value
+ } else if t.value > maxLogit {
+ maxLogit = t.value
+ }
+ }
+
+ // Calculate scaling to map to uint32 range
+ logitRange := maxLogit - minLogit
+ if logitRange < 1e-6 {
+ return // All values effectively equal
+ }
+
+ // Count frequencies directly from tokens
+ const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity
+ var counts [256]int // For first byte
+
+ // First pass: count frequencies
+ for _, t := range tokens {
+ // Map to [0, maxInt] range
+ score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
+ counts[score>>16]++
+ }
+
+ // Calculate offsets
+ var offset int
+ for i := range counts {
+ count := counts[i]
+ counts[i] = offset
+ offset += count
+ }
+
+ // Second pass: place elements in correct position
+ output := make([]token, len(tokens))
+ // Track current positions
+ countsCopy := counts
+
+ for i, t := range tokens {
+ score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
+
+ pos := countsCopy[score>>16]
+ countsCopy[score>>16]++
+ output[len(tokens)-1-pos] = tokens[i]
+ }
+
+ copy(tokens, output)
}
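To make the new transform chain concrete, here is a hedged sketch (not part of the diff) of how the helpers above compose inside the `sample` package. The exact ordering used by the sampler lives in samplers.go; this only illustrates the data flow over the `token` type introduced in this change.

```go
// Illustrative only: apply temperature+softmax, keep the k most likely
// tokens, sort descending, then truncate by cumulative probability.
func illustrateTransforms() []token {
	ts := []token{
		{id: 0, value: 1.0},
		{id: 1, value: 4.0},
		{id: 2, value: -2.0},
		{id: 3, value: 0.5},
	}
	ts = temperature(ts, 0.8) // scale logits and normalize to probabilities
	ts = topK(ts, 3)          // heap-select the 3 highest-probability tokens
	sortLogits(ts)            // topP expects descending order
	ts = topP(ts, 0.95)       // drop the low-probability tail
	return ts
}
```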
diff --git a/sample/transforms_test.go b/sample/transforms_test.go
index 05f76a274..81e8849b7 100644
--- a/sample/transforms_test.go
+++ b/sample/transforms_test.go
@@ -4,77 +4,175 @@ import (
"math"
"math/rand/v2"
"testing"
-
- "github.com/google/go-cmp/cmp"
)
-func TestTemperature(t *testing.T) {
- got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
- want := []float64{-4, -10, 0, -14, -6, -12, -8}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("logits mismatch (-want +got):\n%s", diff)
+// Helper to convert a float64 slice of logits to a token slice
+func toTokens(values []float64) []token {
+ tokens := make([]token, len(values))
+ for i, v := range values {
+ tokens[i] = token{
+ id: int32(i),
+ value: float32(v),
+ }
+ }
+ return tokens
+}
+
+// Helper to compare logit slices
+func compareLogits(t *testing.T, name string, want []float64, got []token) {
+ t.Helper()
+ if len(want) != len(got) {
+ t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
+ return
+ }
+ for i := range want {
+ if math.Abs(float64(got[i].value)-want[i]) > 1e-6 {
+ t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
+ }
}
}
-func TestSoftmax(t *testing.T) {
- got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
+func TestTemperatureAndSoftmax(t *testing.T) {
+ input := []float64{1, 4, -2, 0}
+ got := temperature(toTokens(input), 0.5)
- want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("probs mismatch (-want +got):\n%s", diff)
+ // Check probabilities sum to 1
+ var sum float32
+ for _, token := range got {
+ sum += token.value
+ }
+ if math.Abs(float64(sum)-1.0) > 1e-6 {
+ t.Errorf("probabilities don't sum to 1: got %f", sum)
+ }
+
+ got = temperature(toTokens(input), 1)
+ // Check probabilities sum to 1
+ sum = 0.0
+ for _, token := range got {
+ sum += token.value
+ }
+ if math.Abs(float64(sum)-1.0) > 1e-6 {
+ t.Errorf("probabilities don't sum to 1: got %f", sum)
}
}
func TestTopK(t *testing.T) {
- got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
- want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("logits mismatch (-want +got):\n%s", diff)
- }
+ input := []float64{-3, -2, -1, 0, 1, 2, 4}
- got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
-
- want = []float64{-3, -2, -1, 0, 1, 2, 4}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("logits mismatch (-want +got):\n%s", diff)
+ // Test k=3
+ got := topK(toTokens(input), 3)
+ if len(got) != 3 {
+ t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
}
+ // Should keep highest 3 values: 4, 2, 1
+ want := []float64{4, 2, 1}
+ compareLogits(t, "topK(3)", want, got)
+
+ // Test k > len
+ got = topK(toTokens(input), 10)
+ compareLogits(t, "topK(10)", input, got)
}
func TestTopP(t *testing.T) {
- got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
- want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("logits mismatch (-want +got):\n%s", diff)
+ input := []float64{-3, -2, -1, 0, 1, 2, 4}
+ tokens := toTokens(input)
+
+ // First apply temperature and softmax to get probabilities
+ tokens = temperature(tokens, 1)
+ sortLogits(tokens)
+
+ // Then apply topP
+ got := topP(tokens, 0.95)
+
+ // Should keep tokens until cumsum > 0.95
+ if len(got) > 3 {
+ t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
+ t.Logf("got: %v", got)
}
}
func TestMinP(t *testing.T) {
- got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
- want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("logits mismatch (-want +got):\n%s", diff)
+ input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
+ tokens := toTokens(input)
+
+ // First apply temperature and softmax
+ tokens = temperature(tokens, 1)
+
+ // Then apply minP
+ got := minP(tokens, 0.2)
+
+ // Should keep tokens with prob >= 0.2 * max_prob
+ if len(got) > 3 {
+ t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
}
}
-func BenchmarkTransform(b *testing.B) {
- transforms := map[string]Transform{
- "Temperature": Temperature(0.5),
- "TopK": TopK(10),
- "TopP": TopP(0.9),
- "MinP": MinP(0.2),
+func TestSortLogits(t *testing.T) {
+ input := []float64{3, 1, 4, 2, -1, 0, -2}
+ tokens := toTokens(input)
+
+ sortLogits(tokens)
+
+ for i := 1; i < len(tokens); i++ {
+ if tokens[i].value > tokens[i-1].value {
+ t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
+ i, tokens[i].value, tokens[i-1].value)
+ }
}
- logits := make([]float64, 1<<16)
- for i := range logits {
- logits[i] = rand.Float64()
- }
-
- for name, transform := range transforms {
- b.Run(name, func(b *testing.B) {
- b.ResetTimer()
- for range b.N {
- transform.Apply(logits)
- }
- })
- }
+ want := []float64{4, 3, 2, 1, 0, -1, -2}
+ compareLogits(t, "sortLogits", want, tokens)
+}
+
+func BenchmarkTransforms(b *testing.B) {
+ // Generate random logits
+ tokens := make([]token, 1<<16)
+ for i := range tokens {
+ tokens[i] = token{
+ id: int32(i),
+ value: rand.Float32(),
+ }
+ }
+
+ tokensCopy := make([]token, len(tokens))
+
+ b.Run("Temperature", func(b *testing.B) {
+ b.ResetTimer()
+ for b.Loop() {
+ copy(tokensCopy, tokens)
+ temperature(tokensCopy, 0.5)
+ }
+ })
+
+ b.Run("TopK", func(b *testing.B) {
+ b.ResetTimer()
+ for b.Loop() {
+ copy(tokensCopy, tokens)
+ topK(tokensCopy, 10)
+ }
+ })
+
+ b.Run("TopP", func(b *testing.B) {
+ b.ResetTimer()
+ for b.Loop() {
+ copy(tokensCopy, tokens)
+ topP(tokensCopy, 0.9)
+ }
+ })
+
+ b.Run("MinP", func(b *testing.B) {
+ b.ResetTimer()
+ for b.Loop() {
+ copy(tokensCopy, tokens)
+ minP(tokensCopy, 0.2)
+ }
+ })
+
+ b.Run("SortTokens", func(b *testing.B) {
+ b.ResetTimer()
+ for b.Loop() {
+ copy(tokensCopy, tokens)
+ sortLogits(tokensCopy)
+ }
+ })
}
diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1
index 62930d7f2..60485df85 100644
--- a/scripts/build_windows.ps1
+++ b/scripts/build_windows.ps1
@@ -80,13 +80,14 @@ function checkEnv() {
function buildOllama() {
+ mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0
& cmake --fresh --preset CPU --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
- & cmake --build --preset CPU --parallel $script:JOBS
+ & cmake --build --preset CPU --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component CPU --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -101,7 +102,7 @@ function buildOllama() {
# to avoid 2022 (or newer) from being used as the default
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
- & cmake --build --preset "CUDA 11" --parallel $script:JOBS
+ & cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -112,7 +113,7 @@ function buildOllama() {
write-host "Building CUDA v12 backend libraries"
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
- & cmake --build --preset "CUDA 12" --parallel $script:JOBS
+ & cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -131,7 +132,7 @@ function buildOllama() {
$env:HIPCXX=""
$env:HIP_PLATFORM=""
$env:CMAKE_PREFIX_PATH=""
- & cmake --build --preset "ROCm" --parallel $script:JOBS
+ & cmake --build --preset "ROCm" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
diff --git a/scripts/install.sh b/scripts/install.sh
index 9e146e508..9c232400f 100644
--- a/scripts/install.sh
+++ b/scripts/install.sh
@@ -77,11 +77,12 @@ if [ -d "$OLLAMA_INSTALL_DIR/lib/ollama" ] ; then
fi
status "Installing ollama to $OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d $BINDIR
-$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR"
+$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
status "Downloading Linux ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
+
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
status "Making ollama accessible in the PATH in $BINDIR"
$SUDO ln -sf "$OLLAMA_INSTALL_DIR/ollama" "$BINDIR/ollama"
diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go
index e4c36d7d8..423a6ad23 100644
--- a/server/internal/client/ollama/registry.go
+++ b/server/internal/client/ollama/registry.go
@@ -24,8 +24,10 @@ import (
"os"
"path/filepath"
"runtime"
+ "slices"
"strconv"
"strings"
+ "sync"
"sync/atomic"
"time"
@@ -43,9 +45,9 @@ import (
// Errors
var (
- // ErrManifestNotFound is returned when a manifest is not found in the
+ // ErrModelNotFound is returned when a manifest is not found in the
// cache or registry.
- ErrManifestNotFound = errors.New("manifest not found")
+ ErrModelNotFound = errors.New("model not found")
// ErrManifestInvalid is returned when a manifest found in a local or
// remote cache is invalid.
@@ -53,7 +55,7 @@ var (
// ErrMissingModel is returned when the model part of a name is missing
// or invalid.
- ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
+ ErrNameInvalid = errors.New("invalid or missing name")
// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
@@ -72,19 +74,22 @@ const (
DefaultMaxChunkSize = 8 << 20
)
-// DefaultCache returns a new disk cache for storing models. If the
-// OLLAMA_MODELS environment variable is set, it uses that directory;
-// otherwise, it uses $HOME/.ollama/models.
-func DefaultCache() (*blob.DiskCache, error) {
+var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
dir := os.Getenv("OLLAMA_MODELS")
if dir == "" {
- home, err := os.UserHomeDir()
- if err != nil {
- return nil, err
- }
+ home, _ := os.UserHomeDir()
+ home = cmp.Or(home, ".")
dir = filepath.Join(home, ".ollama", "models")
}
return blob.Open(dir)
+})
+
+// DefaultCache returns the default cache used by the registry. It is
+// configured from the OLLAMA_MODELS environment variable, or defaults to
+// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
+// it uses the current working directory.
+func DefaultCache() (*blob.DiskCache, error) {
+ return defaultCache()
}
// Error is the standard error returned by Ollama APIs. It can represent a
@@ -109,7 +114,18 @@ type Error struct {
}
func (e *Error) Error() string {
- return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
+ var b strings.Builder
+ b.WriteString("registry responded with status ")
+ b.WriteString(strconv.Itoa(e.Status))
+ if e.Code != "" {
+ b.WriteString(": code ")
+ b.WriteString(e.Code)
+ }
+ if e.Message != "" {
+ b.WriteString(": ")
+ b.WriteString(e.Message)
+ }
+ return b.String()
}
func (e *Error) LogValue() slog.Value {
@@ -167,6 +183,10 @@ func CompleteName(name string) string {
// Registry is a client for performing push and pull operations against an
// Ollama registry.
type Registry struct {
+ // Cache is the cache used to store models. If nil, [DefaultCache] is
+ // used.
+ Cache *blob.DiskCache
+
// UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient.
UserAgent string
@@ -205,10 +225,28 @@ type Registry struct {
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
- // NameMask, if set, is the name used to convert non-fully qualified
- // names to fully qualified names. If empty, the default mask
- // ("registry.ollama.ai/library/_:latest") is used.
- NameMask string
+ // Mask, if set, is the name used to convert non-fully qualified names
+ // to fully qualified names. If empty, [DefaultMask] is used.
+ Mask string
+}
+
+func (r *Registry) cache() (*blob.DiskCache, error) {
+ if r.Cache != nil {
+ return r.Cache, nil
+ }
+ return defaultCache()
+}
+
+func (r *Registry) parseName(name string) (names.Name, error) {
+ mask := defaultMask
+ if r.Mask != "" {
+ mask = names.Parse(r.Mask)
+ }
+ n := names.Merge(names.Parse(name), mask)
+ if !n.IsFullyQualified() {
+ return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
+ }
+ return n, nil
}
// DefaultRegistry returns a new Registry configured from the environment. The
@@ -243,52 +281,6 @@ func DefaultRegistry() (*Registry, error) {
return &rc, nil
}
-type PushParams struct {
- // From is an optional destination name for the model. If empty, the
- // destination name is the same as the source name.
- From string
-}
-
-// parseName parses name using [names.ParseExtended] and then merges the name with the
-// default name, and checks that the name is fully qualified. If a digest is
-// present, it parse and returns it with the other fields as their zero values.
-//
-// It returns an error if the name is not fully qualified, or if the digest, if
-// any, is invalid.
-//
-// The scheme is returned as provided by [names.ParseExtended].
-func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
- maskName := defaultMask
- if mask != "" {
- maskName = names.Parse(mask)
- if !maskName.IsFullyQualified() {
- return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask)
- }
- }
- scheme, n, ds := names.ParseExtended(s)
- if !n.IsValid() {
- return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
- }
- n = names.Merge(n, maskName)
- if ds != "" {
- // Digest is present. Validate it.
- d, err = blob.ParseDigest(ds)
- if err != nil {
- return "", names.Name{}, blob.Digest{}, err
- }
- }
-
- // The name check is deferred until after the digest check because we
- // say that digests take precedence over names, and so should there
- // errors when being parsed.
- if !n.IsFullyQualified() {
- return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
- }
-
- scheme = cmp.Or(scheme, "https")
- return scheme, n, d, nil
-}
-
func (r *Registry) maxStreams() int {
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
@@ -308,13 +300,24 @@ func (r *Registry) maxChunkSize() int64 {
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
}
+type PushParams struct {
+ // From is an optional destination name for the model. If empty, the
+ // destination name is the same as the source name.
+ From string
+}
+
// Push pushes the model with the name in the cache to the remote registry.
-func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
+func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
if p == nil {
p = &PushParams{}
}
- m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
+ c, err := r.cache()
+ if err != nil {
+ return err
+ }
+
+ m, err := r.ResolveLocal(cmp.Or(p.From, name))
if err != nil {
return err
}
@@ -337,7 +340,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
t := traceFromContext(ctx)
- scheme, n, _, err := parseName(name, r.NameMask)
+ scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
@@ -363,7 +366,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
n.Model(),
l.Digest,
)
- res, err := r.doOK(ctx, "POST", startURL, nil)
+ res, err := r.send(ctx, "POST", startURL, nil)
if err != nil {
return err
}
@@ -387,7 +390,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
}
req.ContentLength = l.Size
- res, err = doOK(r.client(), req)
+ res, err = sendRequest(r.client(), req)
if err == nil {
res.Body.Close()
}
@@ -407,7 +410,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
n.Model(),
n.Tag(),
)
- res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data))
+ res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data))
if err == nil {
res.Body.Close()
}
@@ -430,8 +433,8 @@ func canRetry(err error) bool {
// chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
-func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
- scheme, n, _, err := parseName(name, r.NameMask)
+func (r *Registry) Pull(ctx context.Context, name string) error {
+ scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
@@ -444,6 +447,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}
+ c, err := r.cache()
+ if err != nil {
+ return err
+ }
+
exists := func(l *Layer) bool {
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
@@ -451,10 +459,15 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
t := traceFromContext(ctx)
- var g errgroup.Group
+ g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.maxStreams())
- for _, l := range m.Layers {
+ layers := m.Layers
+ if m.Config != nil && m.Config.Digest.IsValid() {
+ layers = append(layers, m.Config)
+ }
+
+ for _, l := range layers {
if exists(l) {
t.update(l, l.Size, ErrCached)
continue
@@ -471,7 +484,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
if l.Size <= r.maxChunkingThreshold() {
g.Go(func() error {
- res, err := doOK(r.client(), req)
+ // TODO(bmizerany): retry/backoff like below in
+ // the chunking case
+ res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
@@ -497,19 +512,21 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
// fire an initial request to get the final URL and
// then use that URL for the chunk requests.
req.Header.Set("Range", "bytes=0-0")
- res, err := doOK(r.client(), req)
+ res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
res.Body.Close()
req = res.Request.WithContext(req.Context())
- streamNo := 0
- tws := make([]*bufio.Writer, r.maxStreams()-1)
+ wp := writerPool{size: r.maxChunkSize()}
+
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
+ if ctx.Err() != nil {
+ break
+ }
+
ticket := q.Take()
- bufIdx := streamNo % len(tws)
- streamNo++
g.Go(func() (err error) {
defer func() {
if err != nil {
@@ -523,23 +540,18 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
if err != nil {
return err
}
-
err := func() error {
req := req.Clone(req.Context())
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
- res, err := doOK(r.client(), req)
+ res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
- tw := tws[bufIdx]
- if tw == nil {
- tw = bufio.NewWriterSize(nil, int(r.maxChunkSize()))
- tws[bufIdx] = tw
- }
+ tw := wp.get()
tw.Reset(ticket)
- defer tw.Reset(nil) // release ticket
+ defer wp.put(tw)
_, err = io.CopyN(tw, res.Body, chunk.Size())
if err != nil {
@@ -581,8 +593,12 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
-func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
- _, n, _, err := parseName(name, r.NameMask)
+func (r *Registry) Unlink(name string) (ok bool, _ error) {
+ n, err := r.parseName(name)
+ if err != nil {
+ return false, err
+ }
+ c, err := r.cache()
if err != nil {
return false, err
}
@@ -594,6 +610,9 @@ type Manifest struct {
Name string `json:"-"` // the canonical name of the model
Data []byte `json:"-"` // the raw data of the manifest
Layers []*Layer `json:"layers"`
+
+ // For legacy reasons, we still have to download the config layer.
+ Config *Layer `json:"config"`
}
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
@@ -657,14 +676,18 @@ type Layer struct {
Size int64 `json:"size"`
}
-// ResolveLocal resolves a name to a Manifest in the local cache. The name is
-// parsed using [names.ParseExtended] but the scheme is ignored.
-func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
- _, n, d, err := parseName(name, r.NameMask)
+// ResolveLocal resolves a name to a Manifest in the local cache.
+func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
+ _, n, d, err := r.parseNameExtended(name)
+ if err != nil {
+ return nil, err
+ }
+ c, err := r.cache()
if err != nil {
return nil, err
}
if !d.IsValid() {
+ // No digest, so resolve the manifest by name.
d, err = c.Resolve(n.String())
if err != nil {
return nil, err
@@ -673,7 +696,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
data, err := os.ReadFile(c.GetFile(d))
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
- return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name)
+ return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name)
}
return nil, err
}
@@ -686,7 +709,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
- scheme, n, d, err := parseName(name, r.NameMask)
+ scheme, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
@@ -696,7 +719,7 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
}
- res, err := r.doOK(ctx, "GET", manifestURL, nil)
+ res, err := r.send(ctx, "GET", manifestURL, nil)
if err != nil {
return nil, err
}
@@ -721,7 +744,7 @@ func (r *Registry) client() *http.Client {
}
// newRequest constructs a new request, ready to use, with the given method,
-// url, and body, presigned with client Key and UserAgent.
+// url, and body, pre-signed with client [Key] and [UserAgent].
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
@@ -740,11 +763,17 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R
return req, nil
}
-// doOK makes a request with the given client and request, and returns the
+// sendRequest makes a request with the given client and request, and returns the
// response if the status code is 200. If the status code is not 200, an Error
// is parsed from the response body and returned. If any other error occurs, it
// is returned.
-func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
+func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
+ defer func() {
+ if err != nil {
+ err = fmt.Errorf("request error %s: %w", r.URL, err)
+ }
+ }()
+
if r.URL.Scheme == "https+insecure" {
// TODO(bmizerany): clone client.Transport, set
// InsecureSkipVerify, etc.
@@ -787,20 +816,26 @@ func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
// Use the raw body if we can't parse it as an error object.
re.Message = string(out)
}
+
+	// coerce MANIFEST_UNKNOWN to ErrModelNotFound
+ if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
+ return nil, ErrModelNotFound
+ }
+
re.Status = res.StatusCode
return nil, &re
}
return res, nil
}
-// doOK is a convenience method for making a request with newRequest and
-// passing it to doOK with r.client().
-func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
+// send is a convenience method for making a request with newRequest and
+// passing it to sendRequest with r.client().
+func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
req, err := r.newRequest(ctx, method, path, body)
if err != nil {
return nil, err
}
- return doOK(r.client(), req)
+ return sendRequest(r.client(), req)
}
// makeAuthToken creates an Ollama auth token for the given private key.
@@ -869,3 +904,114 @@ func maybeUnexpectedEOF(err error) error {
}
return err
}
+
+type publicError struct {
+ wrapped error
+ message string
+}
+
+func withPublicMessagef(err error, message string, args ...any) error {
+ return publicError{wrapped: err, message: fmt.Sprintf(message, args...)}
+}
+
+func (e publicError) Error() string { return e.message }
+func (e publicError) Unwrap() error { return e.wrapped }
+
+var supportedSchemes = []string{
+ "http",
+ "https",
+ "https+insecure",
+}
+
+var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
+
+// parseNameExtended parses and validates an extended name, returning the scheme, name,
+// and digest.
+//
+// If the scheme is empty, scheme will be "https". If an unsupported scheme is
+// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
+//
+// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
+// message is returned.
+//
+// If the name, once merged with the mask, is not fully qualified,
+// [ErrNameInvalid] wrapped with a display friendly message is returned.
+func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
+ scheme, name, digest := splitExtended(s)
+ scheme = cmp.Or(scheme, "https")
+ if !slices.Contains(supportedSchemes, scheme) {
+ err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
+ return "", names.Name{}, blob.Digest{}, err
+ }
+
+ var d blob.Digest
+ if digest != "" {
+ var err error
+ d, err = blob.ParseDigest(digest)
+ if err != nil {
+ err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
+ return "", names.Name{}, blob.Digest{}, err
+ }
+ if name == "" {
+			// We can resolve a manifest from a digest only,
+ // so skip name validation and return the scheme and
+ // digest.
+ return scheme, names.Name{}, d, nil
+ }
+ }
+
+ n, err := r.parseName(name)
+ if err != nil {
+ return "", names.Name{}, blob.Digest{}, err
+ }
+ return scheme, n, d, nil
+}
+
+// splitExtended splits an extended name string into its scheme, name, and digest
+// parts.
+//
+// Examples:
+//
+// http://ollama.com/bmizerany/smol:latest@digest
+// https://ollama.com/bmizerany/smol:latest
+// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
+// model@digest
+// @digest
+func splitExtended(s string) (scheme, name, digest string) {
+ i := strings.Index(s, "://")
+ if i >= 0 {
+ scheme = s[:i]
+ s = s[i+3:]
+ }
+ i = strings.LastIndex(s, "@")
+ if i >= 0 {
+ digest = s[i+1:]
+ s = s[:i]
+ }
+ return scheme, s, digest
+}
+
+type writerPool struct {
+ size int64 // set by the caller
+
+ mu sync.Mutex
+ ws []*bufio.Writer
+}
+
+func (p *writerPool) get() *bufio.Writer {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if len(p.ws) == 0 {
+ return bufio.NewWriterSize(nil, int(p.size))
+ }
+ w := p.ws[len(p.ws)-1]
+ p.ws = p.ws[:len(p.ws)-1]
+ return w
+}
+
+func (p *writerPool) put(w *bufio.Writer) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ w.Reset(nil)
+ p.ws = append(p.ws, w)
+}
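The net effect of the registry changes is that callers no longer thread a `*blob.DiskCache` through every call; the Registry resolves its cache from the `Cache` field or `DefaultCache`. A hedged usage sketch follows (the model name is a placeholder, and since the package is internal it only compiles inside the ollama module):

```go
package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/server/internal/client/ollama"
)

func main() {
	rc, err := ollama.DefaultRegistry() // cache comes from rc.Cache or DefaultCache
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()
	if err := rc.Pull(ctx, "library/smol"); err != nil { // placeholder model name
		log.Fatal(err)
	}
	m, err := rc.ResolveLocal("library/smol")
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("pulled %d layers", len(m.Layers))
}
```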
diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go
index af898c268..8f4e1604f 100644
--- a/server/internal/client/ollama/registry_test.go
+++ b/server/internal/client/ollama/registry_test.go
@@ -2,6 +2,7 @@ package ollama
import (
"bytes"
+ "cmp"
"context"
"encoding/json"
"errors"
@@ -72,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()
+
c, err := blob.Open(t.TempDir())
if err != nil {
t.Fatal(err)
@@ -84,14 +86,15 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
}
}
- rc := &Registry{
+ r := &Registry{
+ Cache: c,
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
link := func(name string, manifest string) {
- _, n, _, err := parseName(name, rc.NameMask)
+ n, err := r.parseName(name)
if err != nil {
panic(err)
}
@@ -122,7 +125,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
link("invalid", "!!!!!")
- return rc, c
+ return r, c
}
func okHandler(w http.ResponseWriter, r *http.Request) {
@@ -145,84 +148,61 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
return d
}
-func TestRegistryPushInvalidNames(t *testing.T) {
- rc, c := newClient(t, nil)
-
- cases := []struct {
- name string
- err error
- }{
- {"", ErrNameInvalid},
- {"@", ErrNameInvalid},
- {"@x", blob.ErrInvalidDigest},
- }
-
- for _, tt := range cases {
- t.Run(tt.name, func(t *testing.T) {
- // Create a new registry and push a new image.
- err := rc.Push(t.Context(), c, tt.name, nil)
- if !errors.Is(err, tt.err) {
- t.Errorf("err = %v; want %v", err, tt.err)
- }
- })
- }
-}
-
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
return WithTrace(ctx, t), t
}
func TestPushZero(t *testing.T) {
- rc, c := newClient(t, okHandler)
- err := rc.Push(t.Context(), c, "empty", nil)
+ rc, _ := newClient(t, okHandler)
+ err := rc.Push(t.Context(), "empty", nil)
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
}
}
func TestPushSingle(t *testing.T) {
- rc, c := newClient(t, okHandler)
- err := rc.Push(t.Context(), c, "single", nil)
+ rc, _ := newClient(t, okHandler)
+ err := rc.Push(t.Context(), "single", nil)
testutil.Check(t, err)
}
func TestPushMultiple(t *testing.T) {
- rc, c := newClient(t, okHandler)
- err := rc.Push(t.Context(), c, "multiple", nil)
+ rc, _ := newClient(t, okHandler)
+ err := rc.Push(t.Context(), "multiple", nil)
testutil.Check(t, err)
}
func TestPushNotFound(t *testing.T) {
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("unexpected request: %v", r)
})
- err := rc.Push(t.Context(), c, "notfound", nil)
+ err := rc.Push(t.Context(), "notfound", nil)
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
}
}
func TestPushNullLayer(t *testing.T) {
- rc, c := newClient(t, nil)
- err := rc.Push(t.Context(), c, "null", nil)
+ rc, _ := newClient(t, nil)
+ err := rc.Push(t.Context(), "null", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
}
func TestPushSizeMismatch(t *testing.T) {
- rc, c := newClient(t, nil)
+ rc, _ := newClient(t, nil)
ctx, _ := withTraceUnexpected(t.Context())
- got := rc.Push(ctx, c, "sizemismatch", nil)
+ got := rc.Push(ctx, "sizemismatch", nil)
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
t.Errorf("err = %v; want size mismatch", got)
}
}
func TestPushInvalid(t *testing.T) {
- rc, c := newClient(t, nil)
- err := rc.Push(t.Context(), c, "invalid", nil)
+ rc, _ := newClient(t, nil)
+ err := rc.Push(t.Context(), "invalid", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
@@ -230,7 +210,7 @@ func TestPushInvalid(t *testing.T) {
func TestPushExistsAtRemote(t *testing.T) {
var pushed bool
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/uploads/") {
if !pushed {
// First push. Return an uploadURL.
@@ -258,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) {
check := testutil.Checker(t)
- err := rc.Push(ctx, c, "single", nil)
+ err := rc.Push(ctx, "single", nil)
check(err)
if !errors.Is(errors.Join(errs...), nil) {
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
}
- err = rc.Push(ctx, c, "single", nil)
+ err = rc.Push(ctx, "single", nil)
check(err)
}
func TestPushRemoteError(t *testing.T) {
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(500)
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
return
}
})
- got := rc.Push(t.Context(), c, "single", nil)
+ got := rc.Push(t.Context(), "single", nil)
checkErrCode(t, got, 500, "blob_error")
}
func TestPushLocationError(t *testing.T) {
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", ":///x")
w.WriteHeader(http.StatusAccepted)
})
- got := rc.Push(t.Context(), c, "single", nil)
+ got := rc.Push(t.Context(), "single", nil)
wantContains := "invalid upload URL"
if got == nil || !strings.Contains(got.Error(), wantContains) {
t.Errorf("err = %v; want to contain %v", got, wantContains)
@@ -294,14 +274,14 @@ func TestPushLocationError(t *testing.T) {
}
func TestPushUploadRoundtripError(t *testing.T) {
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Host == "blob.store" {
w.WriteHeader(499) // force RoundTrip error on upload
return
}
w.Header().Set("Location", "http://blob.store/blobs/123")
})
- got := rc.Push(t.Context(), c, "single", nil)
+ got := rc.Push(t.Context(), "single", nil)
if !errors.Is(got, errRoundTrip) {
t.Errorf("got = %v; want %v", got, errRoundTrip)
}
@@ -317,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
os.Remove(c.GetFile(l.Digest))
},
})
- got := rc.Push(ctx, c, "single", nil)
+ got := rc.Push(ctx, "single", nil)
if !errors.Is(got, fs.ErrNotExist) {
t.Errorf("got = %v; want fs.ErrNotExist", got)
}
}
func TestPushCommitRoundtripError(t *testing.T) {
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
panic("unexpected")
}
w.WriteHeader(499) // force RoundTrip error
})
- err := rc.Push(t.Context(), c, "zero", nil)
+ err := rc.Push(t.Context(), "zero", nil)
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
@@ -344,8 +324,8 @@ func checkNotExist(t *testing.T, err error) {
}
func TestRegistryPullInvalidName(t *testing.T) {
- rc, c := newClient(t, nil)
- err := rc.Pull(t.Context(), c, "://")
+ rc, _ := newClient(t, nil)
+ err := rc.Pull(t.Context(), "://")
if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
}
@@ -360,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
}
for _, resp := range cases {
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp)
})
- err := rc.Pull(t.Context(), c, "x")
+ err := rc.Pull(t.Context(), "x")
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err)
}
@@ -386,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) {
})
// Confirm that the layer does not exist locally
- _, err := rc.ResolveLocal(c, "model")
+ _, err := rc.ResolveLocal("model")
checkNotExist(t, err)
_, err = c.Get(d)
checkNotExist(t, err)
- err = rc.Pull(t.Context(), c, "model")
+ err = rc.Pull(t.Context(), "model")
check(err)
mw, err := rc.Resolve(t.Context(), "model")
check(err)
- mg, err := rc.ResolveLocal(c, "model")
+ mg, err := rc.ResolveLocal("model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
@@ -422,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) {
func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists")
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called
return
@@ -445,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
- err := rc.Pull(ctx, c, "single")
+ err := rc.Pull(ctx, "single")
testutil.Check(t, err)
want := []int64{6}
@@ -458,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) {
}
func TestRegistryPullManifestNotFound(t *testing.T) {
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
- err := rc.Pull(t.Context(), c, "notfound")
+ err := rc.Pull(t.Context(), "notfound")
checkErrCode(t, err, 404, "")
}
func TestRegistryPullResolveRemoteError(t *testing.T) {
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
})
- err := rc.Pull(t.Context(), c, "single")
+ err := rc.Pull(t.Context(), "single")
checkErrCode(t, err, 500, "an_error")
}
func TestRegistryPullResolveRoundtripError(t *testing.T) {
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error
return
}
})
- err := rc.Pull(t.Context(), c, "single")
+ err := rc.Pull(t.Context(), "single")
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
@@ -534,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
// Check that we pull all layers that we can.
- err := rc.Pull(ctx, c, "mixed")
+ err := rc.Pull(ctx, "mixed")
if err != nil {
t.Fatal(err)
}
@@ -552,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
}
func TestRegistryPullChunking(t *testing.T) {
- rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
@@ -590,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) {
},
})
- err := rc.Pull(ctx, c, "remote")
+ err := rc.Pull(ctx, "remote")
testutil.Check(t, err)
want := []int64{0, 3, 6}
@@ -622,13 +602,13 @@ func TestInsecureSkipVerify(t *testing.T) {
}))
defer s.Close()
- const name = "ollama.com/library/insecure"
+ const name = "library/insecure"
var rc Registry
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
_, err := rc.Resolve(t.Context(), url)
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
- t.Errorf("err = %v; want cert verifiction failure", err)
+ t.Errorf("err = %v; want cert verification failure", err)
}
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
@@ -724,3 +704,115 @@ func TestErrorUnmarshal(t *testing.T) {
})
}
}
+
+// TestParseNameExtendedErrors tests that parseNameExtended returns error
+// messages with enough detail for users to debug naming issues they may
+// encounter. Before this change, every problem was reported with the same
+// unhelpful message.
+//
+// It only tests error messages, not exhaustive coverage of invalid and valid
+// names. Those are in other tests for names.Name and blob.Digest.
+func TestParseNameExtendedErrors(t *testing.T) {
+ cases := []struct {
+ name string
+ err error
+ want string
+ }{}
+
+ var r Registry
+ for _, tt := range cases {
+ _, _, _, err := r.parseNameExtended(tt.name)
+ if !errors.Is(err, tt.err) {
+ t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
+ }
+ if err != nil && !strings.Contains(err.Error(), tt.want) {
+ t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
+ }
+ }
+}
+
+func TestParseNameExtended(t *testing.T) {
+ cases := []struct {
+ in string
+ scheme string
+ name string
+ digest string
+ err string
+ }{
+ {in: "http://m", scheme: "http", name: "m"},
+ {in: "https+insecure://m", scheme: "https+insecure", name: "m"},
+ {in: "http+insecure://m", err: "unsupported scheme"},
+
+ {in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
+
+ {in: "", err: "invalid or missing name"},
+ {in: "m", scheme: "https", name: "m"},
+ {in: "://", err: "invalid or missing name"},
+ {in: "@sha256:deadbeef", err: "invalid digest"},
+ {in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
+ }
+ for _, tt := range cases {
+ t.Run(tt.in, func(t *testing.T) {
+ var r Registry
+ scheme, n, digest, err := r.parseNameExtended(tt.in)
+ if err != nil {
+ if tt.err == "" {
+ t.Errorf("err = %v; want nil", err)
+ } else if !strings.Contains(err.Error(), tt.err) {
+ t.Errorf("err = %v; want %q", err, tt.err)
+ }
+ } else if tt.err != "" {
+ t.Errorf("err = nil; want %q", tt.err)
+ }
+ if err == nil && !n.IsFullyQualified() {
+ t.Errorf("name = %q; want fully qualified", n)
+ }
+
+ if scheme != tt.scheme {
+ t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
+ }
+
+ // smoke-test name is superset of tt.name
+ if !strings.Contains(n.String(), tt.name) {
+ t.Errorf("name = %q; want %q", n, tt.name)
+ }
+
+ tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
+ if digest.String() != tt.digest {
+ t.Errorf("digest = %q; want %q", digest, tt.digest)
+ }
+ })
+ }
+}
+
+func TestUnlink(t *testing.T) {
+ t.Run("found by name", func(t *testing.T) {
+ rc, _ := newClient(t, nil)
+
+ // confirm linked
+ _, err := rc.ResolveLocal("single")
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+
+ // unlink
+ _, err = rc.Unlink("single")
+ testutil.Check(t, err)
+
+ // confirm unlinked
+ _, err = rc.ResolveLocal("single")
+ if !errors.Is(err, fs.ErrNotExist) {
+ t.Errorf("err = %v; want fs.ErrNotExist", err)
+ }
+ })
+ t.Run("not found by name", func(t *testing.T) {
+ rc, _ := newClient(t, nil)
+ ok, err := rc.Unlink("manifestNotFound")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if ok {
+ t.Error("expected not found")
+ }
+ })
+}
diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go
index 8e53040ad..69435c406 100644
--- a/server/internal/client/ollama/trace.go
+++ b/server/internal/client/ollama/trace.go
@@ -6,13 +6,20 @@ import (
// Trace is a set of functions that are called to report progress during blob
// downloads and uploads.
+//
+// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
+// and [Registry.Pull].
type Trace struct {
// Update is called during [Registry.Push] and [Registry.Pull] to
// report the progress of blob uploads and downloads.
//
- // It is called once at the beginning of the download with a zero n and
- // then once per read operation with the number of bytes read so far,
- // and an error if any.
+ // The n argument is the number of bytes transferred so far, and err is
+ // any error that has occurred. If n == 0, and err is nil, the download
+ // or upload has just started. If err is [ErrCached], the download or
+ // upload has been skipped because the blob is already present in the
+ // local cache or remote registry, respectively. Otherwise, if err is
+ // non-nil, the download or upload has failed. When l.Size == n, and
+ // err is nil, the download or upload has completed.
//
// A function assigned must be safe for concurrent use. The function is
// called synchronously and so should not block or take long to run.
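A hedged sketch of driving a progress callback from the Update semantics documented above (helper and model names are illustrative, not part of the change):

```go
package main

import (
	"context"
	"errors"
	"log"

	"github.com/ollama/ollama/server/internal/client/ollama"
)

// pullWithProgress interprets Update per the semantics documented above:
// err == ErrCached means the blob was skipped, and n == l.Size with a nil
// err means the transfer finished.
func pullWithProgress(rc *ollama.Registry, name string) error {
	ctx := ollama.WithTrace(context.Background(), &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			switch {
			case errors.Is(err, ollama.ErrCached):
				log.Printf("%s: already present, skipped", l.Digest)
			case err != nil:
				log.Printf("%s: failed: %v", l.Digest, err)
			case n == l.Size:
				log.Printf("%s: complete", l.Digest)
			}
		},
	})
	return rc.Pull(ctx, name)
}
```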
diff --git a/server/internal/cmd/opp/opp.go b/server/internal/cmd/opp/opp.go
index c21e71d59..6976927c7 100644
--- a/server/internal/cmd/opp/opp.go
+++ b/server/internal/cmd/opp/opp.go
@@ -63,25 +63,28 @@ func main() {
}
flag.Parse()
- c, err := ollama.DefaultCache()
- if err != nil {
- log.Fatal(err)
- }
-
- rc, err := ollama.DefaultRegistry()
- if err != nil {
- log.Fatal(err)
- }
-
ctx := context.Background()
- err = func() error {
+ err := func() error {
switch cmd := flag.Arg(0); cmd {
case "pull":
- return cmdPull(ctx, rc, c)
+ rc, err := ollama.DefaultRegistry()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ return cmdPull(ctx, rc)
case "push":
- return cmdPush(ctx, rc, c)
+ rc, err := ollama.DefaultRegistry()
+ if err != nil {
+ log.Fatal(err)
+ }
+ return cmdPush(ctx, rc)
case "import":
+ c, err := ollama.DefaultCache()
+ if err != nil {
+ log.Fatal(err)
+ }
return cmdImport(ctx, c)
default:
if cmd == "" {
@@ -99,7 +102,7 @@ func main() {
}
}
-func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
+func cmdPull(ctx context.Context, rc *ollama.Registry) error {
model := flag.Arg(1)
if model == "" {
flag.Usage()
@@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
errc := make(chan error)
go func() {
- errc <- rc.Pull(ctx, c, model)
+ errc <- rc.Pull(ctx, model)
}()
t := time.NewTicker(time.Second)
@@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
}
}
-func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
+func cmdPush(ctx context.Context, rc *ollama.Registry) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("push", flag.ExitOnError)
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
@@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
}
from := cmp.Or(*flagFrom, model)
- m, err := rc.ResolveLocal(c, from)
+ m, err := rc.ResolveLocal(from)
if err != nil {
return err
}
@@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
},
})
- return rc.Push(ctx, c, model, &ollama.PushParams{
+ return rc.Push(ctx, model, &ollama.PushParams{
From: from,
})
}
diff --git a/server/internal/internal/backoff/backoff_test.go b/server/internal/internal/backoff/backoff_test.go
index bb8438a78..11ace22a8 100644
--- a/server/internal/internal/backoff/backoff_test.go
+++ b/server/internal/internal/backoff/backoff_test.go
@@ -1,3 +1,5 @@
+//go:build goexperiment.synctest
+
package backoff
import (
diff --git a/server/internal/internal/names/name.go b/server/internal/internal/names/name.go
index 361cce76f..f0a1185dc 100644
--- a/server/internal/internal/names/name.go
+++ b/server/internal/internal/names/name.go
@@ -8,7 +8,7 @@ import (
"github.com/ollama/ollama/server/internal/internal/stringsx"
)
-const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag>
+const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // <host>/<namespace>/<model>:<tag>
type Name struct {
// Make incomparable to enfoce use of Compare / Equal for
@@ -25,19 +25,12 @@ type Name struct {
// format of a valid name string is:
//
// s:
-// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
// { host } "/" { namespace } "/" { model } ":" { tag }
-// { host } "/" { namespace } "/" { model } "@" { digest }
// { host } "/" { namespace } "/" { model }
-// { namespace } "/" { model } ":" { tag } "@" { digest }
// { namespace } "/" { model } ":" { tag }
-// { namespace } "/" { model } "@" { digest }
// { namespace } "/" { model }
-// { model } ":" { tag } "@" { digest }
// { model } ":" { tag }
-// { model } "@" { digest }
// { model }
-// "@" { digest }
// host:
// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
// length: [1, 350]
@@ -50,9 +43,6 @@ type Name struct {
// tag:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
-// digest:
-// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
-// length: [1, 80]
//
// The name returned is not guaranteed to be valid. If it is not valid, the
// field values are left in an undefined state. Use [Name.IsValid] to check
@@ -82,23 +72,17 @@ func Parse(s string) Name {
}
}
-// ParseExtended parses and returns any scheme, Name, and digest from from s in
-// the the form [scheme://][name][@digest]. All parts are optional.
-//
-// If the scheme is present, it must be followed by "://". The digest is
-// prefixed by "@" and comes after the name. The name is parsed using [Parse].
-//
-// The scheme and digest are stripped before the name is parsed by [Parse].
-//
-// For convience, the scheme is never empty. If the scheme is not present, the
-// returned scheme is "https".
+// Split splits an extended name string into its scheme, name, and digest
+// parts.
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
-func ParseExtended(s string) (scheme string, _ Name, digest string) {
+// model@digest
+// @digest
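+//
+// A rough sketch of how one input splits (illustrative only):
+//
+//	Split("http://ollama.com/bmizerany/smol:latest@sha256:abc")
+//	// "http", "ollama.com/bmizerany/smol:latest", "sha256:abc"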
+func Split(s string) (scheme, name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
@@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) {
digest = s[i+1:]
s = s[:i]
}
- return scheme, Parse(s), digest
-}
-
-func FormatExtended(scheme string, n Name, digest string) string {
- var b strings.Builder
- if scheme != "" {
- b.WriteString(scheme)
- b.WriteString("://")
- }
- b.WriteString(n.String())
- if digest != "" {
- b.WriteByte('@')
- b.WriteString(digest)
- }
- return b.String()
+ return scheme, s, digest
}
// Merge merges two names into a single name. Non-empty host, namespace, and
@@ -141,39 +111,68 @@ func Merge(a, b Name) Name {
// IsValid returns true if the name is valid.
func (n Name) IsValid() bool {
- if n.h != "" && !isValidHost(n.h) {
+ if n.h != "" && !isValidPart(partHost, n.h) {
return false
}
- if n.n != "" && !isValidNamespace(n.n) {
+ if n.n != "" && !isValidPart(partNamespace, n.n) {
return false
}
- if n.m != "" && !isValidModel(n.m) {
+ if n.t != "" && !isValidPart(partTag, n.t) {
return false
}
- if n.t != "" && !isValidTag(n.t) {
- return false
- }
- return true
+
+ // at bare minimum, model must be present and valid
+ return n.m != "" && isValidPart(partModel, n.m)
}
func (n Name) IsFullyQualified() bool {
return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
}
-func isValidHost(_ string) bool {
- return true // TODO: implement
+const (
+ partHost = iota
+ partNamespace
+ partModel
+ partTag
+)
+
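+// isValidPart reports whether s is a valid name part of the given kind. A
+// part must start with an alphanumeric or underscore and may otherwise
+// contain alphanumerics, underscores, and hyphens; dots are allowed in all
+// parts except namespaces, and colons only in hosts. Hosts may be up to 350
+// bytes long; all other parts are limited to 80.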
+func isValidPart(kind int, s string) bool {
+ maxlen := 80
+ if kind == partHost {
+ maxlen = 350
+ }
+ if len(s) > maxlen {
+ return false
+ }
+
+ for i := range s {
+ if i == 0 {
+ if !isAlphanumericOrUnderscore(s[i]) {
+ return false
+ }
+ continue
+ }
+ switch s[i] {
+ case '_', '-':
+ case '.':
+ if kind == partNamespace {
+ return false
+ }
+ case ':':
+ if kind != partHost {
+ return false
+ }
+ default:
+ if !isAlphanumericOrUnderscore(s[i]) {
+ return false
+ }
+ }
+ }
+ return true
}
-func isValidNamespace(_ string) bool {
- return true // TODO: implement
-}
-
-func isValidModel(_ string) bool {
- return true // TODO: implement
-}
-
-func isValidTag(_ string) bool {
- return true // TODO: implement
+func isAlphanumericOrUnderscore(c byte) bool {
+ return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
}
func (n Name) Host() string { return n.h }
diff --git a/server/internal/internal/names/name_test.go b/server/internal/internal/names/name_test.go
index 760fec5fa..e3dc5fe3c 100644
--- a/server/internal/internal/names/name_test.go
+++ b/server/internal/internal/names/name_test.go
@@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) {
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
- scheme, name, digest := ParseExtended(tt.in)
- if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
+ scheme, name, digest := Split(tt.in)
+ n := Parse(name)
+ if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
}
-
- // Round trip
- if got := FormatExtended(scheme, name, digest); got != tt.in {
- t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
- }
})
}
}
@@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) {
junkName = Parse("h/n/m:t")
}
}
+
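+// part80 and part350 are name parts at the maximum allowed lengths (80 and
+// 350 bytes, respectively).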
+const (
+ part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
+ part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
+)
+
+var testCases = map[string]bool{ // name -> valid
+ "": false,
+
+ "_why/_the/_lucky:_stiff": true,
+
+ // minimal
+ "h/n/m:t": true,
+
+ "host/namespace/model:tag": true,
+ "host/namespace/model": true,
+ "namespace/model": true,
+ "model": true,
+
+ // long (but valid)
+ part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
+ part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
+
+ // too long
+ part80 + "/" + part80 + "/" + part80 + ":" + part350: false,
+ "x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false,
+
+ "h/nn/mm:t": true, // bare minimum part sizes
+
+ // unqualified
+ "m": true,
+ "n/m:": true,
+ "h/n/m": true,
+ "@t": false,
+ "m@d": false,
+
+ // invalids
+ "^": false,
+ "mm:": true,
+ "/nn/mm": true,
+ "//": false, // empty model
+ "//mm": true,
+ "hh//": false, // empty model
+ "//mm:@": false,
+ "00@": false,
+ "@": false,
+
+ // not starting with alphanum
+ "-hh/nn/mm:tt": false,
+ "hh/-nn/mm:tt": false,
+ "hh/nn/-mm:tt": false,
+ "hh/nn/mm:-tt": false,
+
+ // smells like a flag
+ "-h": false,
+
+ // hosts
+ "host:https/namespace/model:tag": true,
+
+ // colon in non-host part before tag
+ "host/name:space/model:tag": false,
+}
+
+func TestParseNameValidation(t *testing.T) {
+ for s, valid := range testCases {
+ got := Parse(s)
+ if got.IsValid() != valid {
+ t.Logf("got: %v", got)
+ t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid())
+ }
+ }
+}
diff --git a/server/internal/internal/syncs/line_test.go b/server/internal/internal/syncs/line_test.go
index d52160260..94114a565 100644
--- a/server/internal/internal/syncs/line_test.go
+++ b/server/internal/internal/syncs/line_test.go
@@ -1,3 +1,5 @@
+//go:build goexperiment.synctest
+
package syncs
import (
diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go
index 8eb6daf89..62fefb4c7 100644
--- a/server/internal/registry/server.go
+++ b/server/internal/registry/server.go
@@ -7,9 +7,12 @@ import (
"cmp"
"encoding/json"
"errors"
+ "fmt"
"io"
"log/slog"
"net/http"
+ "sync"
+ "time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
@@ -27,12 +30,15 @@ import (
// directly to the blob disk cache.
type Local struct {
Client *ollama.Registry // required
- Cache *blob.DiskCache // required
Logger *slog.Logger // required
// Fallback, if set, is used to handle requests that are not handled by
// this handler.
Fallback http.Handler
+
+ // Prune, if set, is called to prune the local disk cache after a model
+ // is deleted.
+ Prune func() error // optional
}
// serverError is like ollama.Error, but with a Status field for the HTTP
@@ -107,6 +113,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
switch r.URL.Path {
case "/api/delete":
return false, s.handleDelete(rec, r)
+ case "/api/pull":
+ return false, s.handlePull(rec, r)
default:
if s.Fallback != nil {
s.Fallback.ServeHTTP(rec, r)
@@ -199,13 +207,107 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if err != nil {
return err
}
- ok, err := s.Client.Unlink(s.Cache, p.model())
+ ok, err := s.Client.Unlink(p.model())
if err != nil {
return err
}
if !ok {
- return &serverError{404, "manifest_not_found", "manifest not found"}
+ return &serverError{404, "not_found", "model not found"}
}
+ if s.Prune == nil {
+ return nil
+ }
+ return s.Prune()
+}
+
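+// progressUpdateJSON is the shape of each JSON line streamed to clients by
+// handlePull while a model is being pulled.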
+type progressUpdateJSON struct {
+ Status string `json:"status"`
+ Digest blob.Digest `json:"digest,omitempty,omitzero"`
+ Total int64 `json:"total,omitempty,omitzero"`
+ Completed int64 `json:"completed,omitempty,omitzero"`
+}
+
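+// handlePull handles POST /api/pull. It pulls the requested model from the
+// remote registry and streams newline-delimited progressUpdateJSON updates
+// to the client as layers are downloaded.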
+func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
+ if r.Method != "POST" {
+ return errMethodNotAllowed
+ }
+
+ p, err := decodeUserJSON[*params](r.Body)
+ if err != nil {
+ return err
+ }
+
+ maybeFlush := func() {
+ fl, _ := w.(http.Flusher)
+ if fl != nil {
+ fl.Flush()
+ }
+ }
+ defer maybeFlush()
+
+ var mu sync.Mutex
+ enc := json.NewEncoder(w)
+ enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
+
+ ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
+ Update: func(l *ollama.Layer, n int64, err error) {
+ mu.Lock()
+ defer mu.Unlock()
+
+ // TODO(bmizerany): coalesce these updates; writing per
+ // update is expensive
+ enc.Encode(progressUpdateJSON{
+ Digest: l.Digest,
+ Status: "pulling",
+ Total: l.Size,
+ Completed: n,
+ })
+ },
+ })
+
+ done := make(chan error, 1)
+ go func() {
+ // TODO(bmizerany): continue to support non-streaming responses
+ done <- s.Client.Pull(ctx, p.model())
+ }()
+
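+ // Periodically flush buffered progress updates so the client sees
+ // progress while the pull is in flight, then emit the final status
+ // lines once the pull returns.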
+ func() {
+ t := time.NewTicker(100 * time.Millisecond)
+ defer t.Stop()
+ for {
+ select {
+ case <-t.C:
+ mu.Lock()
+ maybeFlush()
+ mu.Unlock()
+ case err := <-done:
+ if err != nil {
+ var status string
+ if errors.Is(err, ollama.ErrModelNotFound) {
+ status = fmt.Sprintf("error: model %q not found", p.model())
+ enc.Encode(progressUpdateJSON{Status: status})
+ } else {
+ status = fmt.Sprintf("error: %v", err)
+ enc.Encode(progressUpdateJSON{Status: status})
+ }
+ return
+ }
+
+ // These final updates are not strictly necessary, because the work they
+ // describe has already happened by this point. Our pull handler used to
+ // perform these steps after, not during, the pull, and they were slow,
+ // so we wanted to give users feedback about what was happening. For now,
+ // we keep them so as not to jar users who are used to seeing them. We
+ // can phase them out with a new and nicer UX later, one without progress
+ // bars and digests that no one cares about.
+ enc.Encode(progressUpdateJSON{Status: "verifying layers"})
+ enc.Encode(progressUpdateJSON{Status: "writing manifest"})
+ enc.Encode(progressUpdateJSON{Status: "success"})
+ return
+ }
+ }
+ }()
+
return nil
}
diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go
index 22267ba7d..597e9bd63 100644
--- a/server/internal/registry/server_test.go
+++ b/server/internal/registry/server_test.go
@@ -1,17 +1,27 @@
package registry
import (
+ "bytes"
+ "context"
"encoding/json"
+ "fmt"
+ "io"
+ "io/fs"
+ "net"
"net/http"
"net/http/httptest"
"os"
"regexp"
"strings"
+ "sync"
"testing"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/testutil"
+ "golang.org/x/tools/txtar"
+
+ _ "embed"
)
type panicTransport struct{}
@@ -30,7 +40,7 @@ type bytesResetter interface {
Reset()
}
-func newTestServer(t *testing.T) *Local {
+func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
t.Helper()
dir := t.TempDir()
err := os.CopyFS(dir, os.DirFS("testdata/models"))
@@ -41,11 +51,26 @@ func newTestServer(t *testing.T) *Local {
if err != nil {
t.Fatal(err)
}
- rc := &ollama.Registry{
- HTTPClient: panicOnRoundTrip,
+
+ client := panicOnRoundTrip
+ if upstreamRegistry != nil {
+ s := httptest.NewTLSServer(upstreamRegistry)
+ t.Cleanup(s.Close)
+ tr := s.Client().Transport.(*http.Transport).Clone()
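+ // Dial the local TLS test server no matter which host the request
+ // names, so that requests to example.com are answered by
+ // upstreamRegistry.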
+ tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
+ var d net.Dialer
+ return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
+ }
+ client = &http.Client{Transport: tr}
}
+
+ rc := &ollama.Registry{
+ Cache: c,
+ HTTPClient: client,
+ Mask: "example.com/library/_:latest",
+ }
+
l := &Local{
- Cache: c,
Client: rc,
Logger: testutil.Slogger(t),
}
@@ -85,9 +110,9 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
func TestServerDelete(t *testing.T) {
check := testutil.Checker(t)
- s := newTestServer(t)
+ s := newTestServer(t, nil)
- _, err := s.Client.ResolveLocal(s.Cache, "smol")
+ _, err := s.Client.ResolveLocal("smol")
check(err)
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
@@ -95,7 +120,7 @@ func TestServerDelete(t *testing.T) {
t.Fatalf("Code = %d; want 200", got.Code)
}
- _, err = s.Client.ResolveLocal(s.Cache, "smol")
+ _, err = s.Client.ResolveLocal("smol")
if err == nil {
t.Fatal("expected smol to have been deleted")
}
@@ -109,11 +134,8 @@ func TestServerDelete(t *testing.T) {
got = s.send(t, "DELETE", "/api/delete", ``)
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
- got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`)
- checkErrorResponse(t, got, 404, "manifest_not_found", "not found")
-
got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
- checkErrorResponse(t, got, 400, "bad_request", "invalid name")
+ checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
checkErrorResponse(t, got, 404, "not_found", "not found")
@@ -130,8 +152,105 @@ func TestServerDelete(t *testing.T) {
}
}
+//go:embed testdata/registry.txt
+var registryTXT []byte
+
+var registryFS = sync.OnceValue(func() fs.FS {
+ // Txtar gets hung up on \r\n line endings, so we need to convert them
+ // to \n when parsing the txtar on Windows.
+ data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
+ a := txtar.Parse(data)
+ fmt.Printf("%q\n", a.Comment)
+ fsys, err := txtar.FS(a)
+ if err != nil {
+ panic(err)
+ }
+ return fsys
+})
+
+func TestServerPull(t *testing.T) {
+ modelsHandler := http.FileServerFS(registryFS())
+ s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/v2/library/BOOM/manifests/latest":
+ w.WriteHeader(999)
+ io.WriteString(w, `{"error": "boom"}`)
+ case "/v2/library/unknown/manifests/latest":
+ w.WriteHeader(404)
+ io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
+ default:
+ t.Logf("serving file: %s", r.URL.Path)
+ modelsHandler.ServeHTTP(w, r)
+ }
+ })
+
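+ // checkResponse asserts that each line of wantlines appears in the
+ // response body, and that any line prefixed with "!" does not.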
+ checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
+ t.Helper()
+
+ if got.Code != 200 {
+ t.Fatalf("Code = %d; want 200", got.Code)
+ }
+ gotlines := got.Body.String()
+ t.Logf("got:\n%s", gotlines)
+ for want := range strings.Lines(wantlines) {
+ want = strings.TrimSpace(want)
+ want, unwanted := strings.CutPrefix(want, "!")
+ want = strings.TrimSpace(want)
+ if !unwanted && !strings.Contains(gotlines, want) {
+ t.Fatalf("! missing %q in body", want)
+ }
+ if unwanted && strings.Contains(gotlines, want) {
+ t.Fatalf("! unexpected %q in body", want)
+ }
+ }
+ }
+
+ got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
+ checkResponse(got, `
+ {"status":"pulling manifest"}
+ {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
+ `)
+
+ got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
+ checkResponse(got, `
+ {"status":"pulling manifest"}
+ {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
+ {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
+ {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
+ {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
+ {"status":"verifying layers"}
+ {"status":"writing manifest"}
+ {"status":"success"}
+ `)
+
+ got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
+ checkResponse(got, `
+ {"status":"pulling manifest"}
+ {"status":"error: model \"unknown\" not found"}
+ `)
+
+ got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
+ checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
+
+ got = s.send(t, "POST", "/api/pull", `!`)
+ checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
+
+ got = s.send(t, "POST", "/api/pull", ``)
+ checkErrorResponse(t, got, 400, "bad_request", "empty request body")
+
+ got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
+ checkResponse(got, `
+ {"status":"pulling manifest"}
+ {"status":"error: invalid or missing name: \"\""}
+
+ !verifying
+ !writing
+ !success
+ `)
+}
+
func TestServerUnknownPath(t *testing.T) {
- s := newTestServer(t)
+ s := newTestServer(t, nil)
got := s.send(t, "DELETE", "/api/unknown", `{}`)
checkErrorResponse(t, got, 404, "not_found", "not found")
}
diff --git a/server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest b/server/internal/registry/testdata/models/manifests/example.com/library/smol/latest
similarity index 100%
rename from server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest
rename to server/internal/registry/testdata/models/manifests/example.com/library/smol/latest
diff --git a/server/internal/registry/testdata/registry.txt b/server/internal/registry/testdata/registry.txt
new file mode 100644
index 000000000..2fc363fcb
--- /dev/null
+++ b/server/internal/registry/testdata/registry.txt
@@ -0,0 +1,22 @@
+-- v2/library/smol/manifests/latest --
+{
+ "schemaVersion": 2,
+ "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
+ "config": {
+ "mediaType": "application/vnd.docker.container.image.v1+json",
+ "digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356",
+ "size": 3
+ },
+ "layers": [
+ {
+ "mediaType": "application/vnd.ollama.image.model",
+ "digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312",
+ "size": 5
+ }
+ ]
+}
+
+-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 --
+GGUF
+-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 --
+{}
diff --git a/server/prompt.go b/server/prompt.go
index 233dffd69..5b5b958f1 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -10,7 +10,6 @@ import (
"strings"
"github.com/ollama/ollama/api"
- "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/template"
@@ -93,7 +92,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
var imgData llm.ImageData
if isMllama {
- if envconfig.NewEngine() {
+ if len(m.ProjectorPaths) == 0 {
imgData = llm.ImageData{
ID: len(images),
Data: i,
diff --git a/server/routes.go b/server/routes.go
index ff42000f8..3efa12e43 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/openai"
- "github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@@ -43,6 +42,12 @@ import (
"github.com/ollama/ollama/version"
)
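+// experimentEnabled reports whether name appears in the comma-separated
+// OLLAMA_EXPERIMENT environment variable.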
+func experimentEnabled(name string) bool {
+ return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
+}
+
+var useClient2 = experimentEnabled("client2")
+
var mode string = gin.DebugMode
type Server struct {
@@ -206,7 +211,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
- if isMllama && !envconfig.NewEngine() {
+ if isMllama && len(model.ProjectorPaths) > 0 {
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
@@ -1129,7 +1134,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}
}
-func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
+func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
corsConfig := cors.DefaultConfig()
corsConfig.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true
@@ -1174,6 +1179,7 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha
r.HEAD("/api/tags", s.ListHandler)
r.GET("/api/tags", s.ListHandler)
r.POST("/api/show", s.ShowHandler)
+ r.DELETE("/api/delete", s.DeleteHandler)
// Create
r.POST("/api/create", s.CreateHandler)
@@ -1195,15 +1201,19 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
- // wrap old with new
- rs := ®istry.Local{
- Cache: c,
- Client: rc,
- Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
- Fallback: r,
+ if rc != nil {
+ // wrap old with new
+ rs := ®istry.Local{
+ Client: rc,
+ Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
+ Fallback: r,
+
+ Prune: PruneLayers,
+ }
+ return rs, nil
}
- return rs, nil
+ return r, nil
}
func Serve(ln net.Listener) error {
@@ -1258,19 +1268,20 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()}
- c, err := ollama.DefaultCache()
- if err != nil {
- return err
+ var rc *ollama.Registry
+ if useClient2 {
+ var err error
+ rc, err = ollama.DefaultRegistry()
+ if err != nil {
+ return err
+ }
}
- rc, err := ollama.DefaultRegistry()
+
+ h, err := s.GenerateRoutes(rc)
if err != nil {
return err
}
- h, err := s.GenerateRoutes(c, rc)
- if err != nil {
- return err
- }
http.Handle("/", h)
ctx, done := context.WithCancel(context.Background())
diff --git a/server/routes_test.go b/server/routes_test.go
index 0dd782f4f..e13c4b599 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -23,7 +23,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/openai"
- "github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) {
modelsDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", modelsDir)
- c, err := blob.Open(modelsDir)
- if err != nil {
- t.Fatalf("failed to open models dir: %v", err)
- }
-
rc := &ollama.Registry{
// This is a temporary measure to allow us to move forward,
// surfacing any code contacting ollama.com we do not intended
@@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) {
}
s := &Server{}
- router, err := s.GenerateRoutes(c, rc)
+ router, err := s.GenerateRoutes(rc)
if err != nil {
t.Fatalf("failed to generate routes: %v", err)
}