diff --git a/CMakeLists.txt b/CMakeLists.txt index 92b1793b6..034fc7d79 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,9 +106,11 @@ if(CMAKE_HIP_COMPILER) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip) if (WIN32) - target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1) + target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY) endif() + target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM) + set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm) install(TARGETS ggml-hip RUNTIME_DEPENDENCIES diff --git a/Dockerfile b/Dockerfile index 46d4713e7..4136fca71 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,7 +12,7 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base RUN yum install -y yum-utils \ && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \ && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \ - && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \ + && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \ && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH @@ -86,10 +86,11 @@ RUN --mount=type=cache,target=/root/.ccache \ && cmake --install build --component CUDA --strip --parallel 8 FROM base AS build -ARG GOVERSION=1.23.4 -RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local -ENV PATH=/usr/local/go/bin:$PATH WORKDIR /go/src/github.com/ollama/ollama +COPY go.mod go.sum . +RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local +ENV PATH=/usr/local/go/bin:$PATH +RUN go mod download COPY . . ARG GOFLAGS="'-ldflags=-w -s'" ENV CGO_ENABLED=1 diff --git a/README.md b/README.md index 8d471f5ef..b4df5e2af 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
ollama
@@ -54,6 +54,7 @@ Here are some example models that can be downloaded: | Model | Parameters | Size | Download | | ------------------ | ---------- | ----- | -------------------------------- | +| QwQ | 32B | 20GB | `ollama run qwq` | | DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` | | DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` | | Llama 3.3 | 70B | 43GB | `ollama run llama3.3` | @@ -64,7 +65,7 @@ Here are some example models that can be downloaded: | Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` | | Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` | | Phi 4 | 14B | 9.1GB | `ollama run phi4` | -| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` | +| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` | | Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` | | Gemma 2 | 9B | 5.5GB | `ollama run gemma2` | | Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` | @@ -75,7 +76,7 @@ Here are some example models that can be downloaded: | Code Llama | 7B | 3.8GB | `ollama run codellama` | | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` | | LLaVA | 7B | 4.5GB | `ollama run llava` | -| Solar | 10.7B | 6.1GB | `ollama run solar` | +| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` | > [!NOTE] > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models. @@ -275,6 +276,7 @@ See the [API documentation](./docs/api.md) for all endpoints. ### Web & Desktop - [Open WebUI](https://github.com/open-webui/open-webui) +- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat) - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted) - [Hollama](https://github.com/fmaclen/hollama) - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui) @@ -387,6 +389,8 @@ See the [API documentation](./docs/api.md) for all endpoints. - [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models) - [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms) - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool) +- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration) +- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) ### Cloud @@ -430,6 +434,7 @@ See the [API documentation](./docs/api.md) for all endpoints. ### Apple Vision Pro +- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad") - [Enchanted](https://github.com/AugustDev/enchanted) ### Database @@ -507,10 +512,13 @@ See the [API documentation](./docs/api.md) for all endpoints. 
### Mobile +- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad) - [Enchanted](https://github.com/AugustDev/enchanted) - [Maid](https://github.com/Mobile-Artificial-Intelligence/maid) - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama) - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption) +- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device) +- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) ### Extensions & Plugins @@ -556,12 +564,14 @@ See the [API documentation](./docs/api.md) for all endpoints. - [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language - [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai) - [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c) +- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs) ### Supported backends - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. ### Observability +- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration with Ollama. - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing. - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics. - [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production. diff --git a/api/types.go b/api/types.go index 14bb7eab4..bc5aad8f1 100644 --- a/api/types.go +++ b/api/types.go @@ -403,9 +403,9 @@ type CopyRequest struct { // PullRequest is the request passed to [Client.Pull].
type PullRequest struct { Model string `json:"model"` - Insecure bool `json:"insecure,omitempty"` - Username string `json:"username"` - Password string `json:"password"` + Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored + Username string `json:"username"` // Deprecated: ignored + Password string `json:"password"` // Deprecated: ignored Stream *bool `json:"stream,omitempty"` // Deprecated: set the model name with Model instead diff --git a/cmd/cmd.go b/cmd/cmd.go index 80ece4c60..c22a08f43 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -34,7 +34,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" - "github.com/ollama/ollama/llama" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/runner" @@ -256,6 +255,7 @@ func StopHandler(cmd *cobra.Command, args []string) error { if strings.Contains(err.Error(), "not found") { return fmt.Errorf("couldn't find model \"%s\" to stop", args[0]) } + return err } return nil } @@ -338,10 +338,16 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - // TODO(jessegross): We should either find another way to know if this is - // a vision model or remove the logic. Also consider that other modalities will - // need different behavior anyways. - opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine() + if len(info.ProjectorInfo) != 0 { + opts.MultiModal = true + } + for k := range info.ModelInfo { + if strings.Contains(k, ".vision.") { + opts.MultiModal = true + break + } + } + opts.ParentModel = info.Details.ParentModel if interactive { @@ -1274,7 +1280,6 @@ func NewCLI() *cobra.Command { runnerCmd := &cobra.Command{ Use: "runner", - Short: llama.PrintSystemInfo(), Hidden: true, RunE: func(cmd *cobra.Command, args []string) error { return runner.Execute(os.Args[1:]) diff --git a/docs/development.md b/docs/development.md index eb67dbfaf..cf6d91e27 100644 --- a/docs/development.md +++ b/docs/development.md @@ -118,6 +118,35 @@ To run tests, use `go test`: go test ./... ``` +> NOTE: In rare circumstances, you may need to change a package that uses the new +> "synctest" package in go1.24. +> +> If you do not have the "synctest" package enabled, you will not see build or +> test failures resulting from your changes locally, but CI will +> break. +> +> If you see failures in CI, you can either keep pushing changes to see if the +> CI build passes, or you can enable the "synctest" package locally to see the +> failures before pushing. +> +> To enable the "synctest" package for testing, run the following command: +> +> ```shell +> GOEXPERIMENT=synctest go test ./... +> ``` +> +> If you wish to enable synctest for all go commands, you can set the +> `GOEXPERIMENT` environment variable in your shell profile or by using: +> +> ```shell +> go env -w GOEXPERIMENT=synctest +> ``` +> +> This will enable the "synctest" package for all go commands without needing +> to set it in each shell session. +> +> The synctest package is not required for production builds. + ## Library detection Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable: diff --git a/docs/faq.md b/docs/faq.md index 04e8433de..4aaccc2e4 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md). ## How can I specify the context window size? -By default, Ollama uses a context window size of 2048 tokens.
+By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`. To change this when using `ollama run`, use `/set parameter`: diff --git a/docs/linux.md b/docs/linux.md index 12581bdd2..2dda87f3a 100644 --- a/docs/linux.md +++ b/docs/linux.md @@ -75,7 +75,7 @@ RestartSec=3 Environment="PATH=$PATH" [Install] -WantedBy=default.target +WantedBy=multi-user.target ``` Then start the service: diff --git a/docs/windows.md b/docs/windows.md index 018cc41d0..78b99a5d7 100644 --- a/docs/windows.md +++ b/docs/windows.md @@ -81,9 +81,11 @@ help you keep up to date. If you'd like to install or integrate Ollama as a service, a standalone `ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI -and GPU library dependencies for Nvidia and AMD. This allows for embedding -Ollama in existing applications, or running it as a system service via `ollama -serve` with tools such as [NSSM](https://nssm.cc/). +and GPU library dependencies for Nvidia. If you have an AMD GPU, also download +and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the +same directory. This allows for embedding Ollama in existing applications, or +running it as a system service via `ollama serve` with tools such as +[NSSM](https://nssm.cc/). > [!NOTE] > If you are upgrading from a prior version, you should remove the old directories first. diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index b9f9cc178..8662c3b01 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -565,6 +565,43 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO return } +func (llm GGML) VisionGraphSize() (weights, graphSize uint64) { + switch llm.KV().Architecture() { + case "mllama": + for _, layer := range llm.Tensors().GroupLayers()["v"] { + weights += layer.Size() + } + + kv := func(n string) uint64 { + if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok { + return uint64(v) + } + + return 0 + } + + imageSize := kv("image_size") + + maxNumTiles := kv("max_num_tiles") + embeddingLength := kv("embedding_length") + headCount := kv("attention.head_count") + + numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size")) + if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok { + numPatches++ + } + + numPaddedPatches := numPatches + 8 - (numPatches%8)%8 + + graphSize = 4 * (8 + + imageSize*imageSize*kv("num_channels")*maxNumTiles + + embeddingLength*numPatches*maxNumTiles + + 9*embeddingLength*numPaddedPatches*maxNumTiles + + numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount) + } + return weights, graphSize +} + // SupportsKVCacheType checks if the requested cache type is supported func (f GGML) SupportsKVCacheType(cacheType string) bool { return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType) diff --git a/go.mod b/go.mod index af0cedc86..cc5789005 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/nlpodyssey/gopickle v0.3.0 github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c golang.org/x/image v0.22.0 - gonum.org/v1/gonum v0.15.0 + golang.org/x/tools v0.30.0 ) require ( @@ -44,6 +44,7 @@ require ( github.com/xtgo/set v1.0.0 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gonum.org/v1/gonum v0.15.0 // indirect 
gorgonia.org/vecf32 v0.9.0 // indirect gorgonia.org/vecf64 v0.9.0 // indirect ) diff --git a/go.sum b/go.sum index 013a7db71..0ab97b909 100644 --- a/go.sum +++ b/go.sum @@ -309,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/kvcache/cache.go b/kvcache/cache.go index 5d8b2f9b5..d35489057 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) var ( @@ -29,6 +30,17 @@ type Cache interface { // cache implementation used. Put(ctx ml.Context, key, value ml.Tensor) + // SetConfig controls optimizations (mostly backend-specific) that may transform + // the output of the cache to work better with specific kernels. If not called, + // the backend settings will be used. This works well when calling Attention. + // + // The config can be overridden by models, especially if they require vanilla + // output when implementing their own version of attention. To do this, pass + // an empty ml.CacheConfig. + // + // Most models will not need to use this. + SetConfig(ml.CacheConfig) + // ** cache management ** // Init sets up runtime parameters @@ -40,7 +52,7 @@ type Cache interface { // StartForward is called before the start of the model's forward pass. // For each token in the coming batch, there must be a corresponding // entry in positions and seqs. 
- StartForward(ctx ml.Context, positions []int32, seqs []int) error + StartForward(ctx ml.Context, opts input.Options) error // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq CopyPrefix(srcSeq, dstSeq int, len int32) diff --git a/kvcache/causal.go b/kvcache/causal.go index 69068439e..34d5337cf 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -8,6 +8,7 @@ import ( "slices" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) @@ -20,8 +21,12 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e type Causal struct { DType ml.DType Capacity int32 + causal bool windowSize int32 + // config controls mostly backend-specific optimizations + config *ml.CacheConfig + // ** current forward pass ** // the active layer for Get and Put @@ -39,6 +44,12 @@ type Causal struct { // locations in the cache that are needed for this batch curCellRange cellRange + // curSequences is the sequences corresponding to this pass's entries in the cache + curSequences []int + + // curPositions is the positions corresponding to this pass's entries in the cache + curPositions []int32 + // ** cache metadata ** // for each possible location in the cache, stores the position and set of sequences @@ -52,8 +63,8 @@ type Causal struct { shiftFn shiftFn backend ml.Backend - cacheCtx ml.Context - keys, values []ml.Tensor + ctxs map[int]ml.Context + keys, values map[int]ml.Tensor } type cacheCell struct { @@ -67,28 +78,73 @@ type cellRange struct { } func NewCausalCache(shift shiftFn) *Causal { - return &Causal{windowSize: math.MaxInt32, shiftFn: shift} + return &Causal{ + causal: true, + windowSize: math.MaxInt32, + shiftFn: shift, + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), + } } func NewSWACache(windowSize int32, shift shiftFn) *Causal { - return &Causal{windowSize: windowSize, shiftFn: shift} + return &Causal{ + causal: true, + windowSize: windowSize, + shiftFn: shift, + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), + } } func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + if c.config == nil { + var config ml.CacheConfig + if cc, ok := backend.(ml.BackendCacheConfig); ok { + config = cc.CacheConfig() + } + c.config = &config + } + + if c.config.CachePadding == 0 { + c.config.CachePadding = 1 + } + + if c.config.MaskBatchPadding == 0 { + c.config.MaskBatchPadding = 1 + } + + if c.config.MaskDType == ml.DTypeOther { + c.config.MaskDType = ml.DTypeF32 + } + c.DType = dtype - c.Capacity = capacity - c.cells = make([]cacheCell, capacity) + c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding)) + c.cells = make([]cacheCell, c.Capacity) c.cellRanges = make(map[int]cellRange) c.backend = backend - c.cacheCtx = backend.NewContext() +} + +func (c *Causal) SetConfig(config ml.CacheConfig) { + if c.config != nil { + panic("config cannot be changed after being previously set, either by the model or backend") + } + + c.config = &config } func (c *Causal) Close() { - c.cacheCtx.Close() + for _, ctx := range c.ctxs { + ctx.Close() + } } -func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error { - c.curBatchSize = len(positions) +func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error { + c.curBatchSize = len(opts.Positions) + c.curSequences = opts.Sequences + c.curPositions = opts.Positions 
var err error c.curLoc, err = c.findStartLoc() @@ -101,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err } c.curCellRange = newRange() - for i, pos := range positions { - seq := seqs[i] + for i, pos := range opts.Positions { + seq := opts.Sequences[i] c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} @@ -127,7 +183,7 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err c.cellRanges[seq] = seqRange } - c.curMask, err = c.buildMask(ctx, positions, seqs) + c.curMask, err = c.buildMask(ctx) return err } @@ -157,36 +213,90 @@ func (c *Causal) findStartLoc() (int, error) { return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity) } +func roundDown(length, pad int) int { + return (length / pad) * pad +} + +func roundUp(length, pad int) int { + return ((length + pad - 1) / pad) * pad +} + // Builds a mask of history x batch indicating whether for each token in the batch the // token in the history should apply. This is based on both the sequence and causality (the // position of the history is not ahead of the token in the batch). -func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) { - // TODO(jessegross): This does not do padding, which is required for flash attention - len := c.curCellRange.max - c.curCellRange.min + 1 - mask := make([]float32, c.curBatchSize*len) +func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { + // Align and pad the two dimensions as required by the backend + batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) + + c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) + c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 + + length := c.curCellRange.max - c.curCellRange.min + 1 + mask := make([]float32, batchSize*length) for i := range c.curBatchSize { for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { - if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] || - c.cells[j].pos < positions[i]-c.windowSize { - mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1)) + if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || + (c.causal && c.cells[j].pos > c.curPositions[i]) || + c.cells[j].pos < c.curPositions[i]-c.windowSize { + mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) } } } - return ctx.FromFloatSlice(mask, len, c.curBatchSize) + // Mask out any padding tokens we added. For padding that we added to the cache history, this + // has already been masked out because the sequence doesn't match. + for i := c.curBatchSize * length; i < len(mask); i++ { + mask[i] = float32(math.Inf(-1)) + } + + maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize) + if err != nil { + return nil, err + } + + if c.config.MaskDType != ml.DTypeF32 { + out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...) 
+ ctx.Forward(maskTensor.Copy(ctx, out)) + maskTensor = out + } + + return maskTensor, nil } -func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) { - for _, obj := range objs { - if obj == nil { +func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { + for i, key := range c.keys { + if key == nil { continue } - srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len) - dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len) + kHeadDim := key.Dim(0) + numKVHeads := key.Dim(1) + rowSize := key.Stride(2) - ctx.Forward(srcView.Copy(ctx, dstView)) + kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len) + kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len) + + value := c.values[i] + var vSrcView, vDstView ml.Tensor + if c.config.PermutedV { + vHeadDim := value.Dim(1) + elemSize := value.Stride(0) + + vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) + vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) + } else { + vHeadDim := value.Dim(0) + rowSize := value.Stride(2) + + vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len) + vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len) + } + + ctx.Forward( + kSrcView.Copy(ctx, kDstView), + vSrcView.Copy(ctx, vDstView), + ) } } @@ -219,7 +329,7 @@ func (c *Causal) defrag() { layers++ } - maxMoves := ctx.MaxTensors() / (6 * layers) + maxMoves := ctx.MaxGraphNodes() / (6 * layers) moves := 0 var pendingSrc, pendingDst, pendingLen int @@ -238,8 +348,7 @@ func (c *Causal) defrag() { pendingLen++ break } else { - moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) - moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + c.moveCells(ctx, pendingSrc, pendingDst, pendingLen) moves++ } } @@ -263,8 +372,7 @@ func (c *Causal) defrag() { } if pendingLen > 0 { - moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) - moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + c.moveCells(ctx, pendingSrc, pendingDst, pendingLen) moves++ } @@ -293,47 +401,106 @@ func (c *Causal) defrag() { } func (c *Causal) SetLayer(layer int) { - if layer >= len(c.keys) { - c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...) - c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...) - } - c.curLayer = layer } +// SetCausal enables or disables causal mask generation for subsequent calls to Get. +// This state carries over to future forward passes. The default value is true. +// +// ctx may be set to nil if this is called from outside of a forward pass, for +// example, when initializing the cache. 
+func (c *Causal) SetCausal(ctx ml.Context, causal bool) { + if c.causal != causal { + c.causal = causal + + if ctx != nil { + var err error + c.curMask, err = c.buildMask(ctx) + if err != nil { + // This error should never occur because we have previously built a mask with the same shape + panic(fmt.Errorf("SetCausal: %w", err)) + } + } + } +} + func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { key := c.keys[c.curLayer] value := c.values[c.curLayer] - key = key.View(ctx, key.Stride(2)*c.curCellRange.min, - key.Dim(0), key.Stride(1), - key.Dim(1), key.Stride(2), - c.curMask.Dim(0), + kHeadDim := key.Dim(0) + numKVHeads := key.Dim(1) + rowSize := key.Stride(2) + cachedSize := c.curMask.Dim(0) + + key = key.View(ctx, rowSize*c.curCellRange.min, + kHeadDim, key.Stride(1), + numKVHeads, key.Stride(2), + cachedSize, ) - value = value.View(ctx, key.Stride(2)*c.curCellRange.min, - value.Dim(0), value.Stride(1), - value.Dim(1), value.Stride(2), - c.curMask.Dim(0), - ) + if c.config.PermutedV { + vHeadDim := value.Dim(1) + elemSize := value.Stride(0) + + value = value.View(ctx, elemSize*c.curCellRange.min, + cachedSize, value.Stride(1), + vHeadDim, value.Stride(2), + numKVHeads, + ) + } else { + vHeadDim := value.Dim(0) + rowSize := value.Stride(2) + + value = value.View(ctx, rowSize*c.curCellRange.min, + vHeadDim, value.Stride(1), + numKVHeads, value.Stride(2), + cachedSize, + ) + } return key, value, c.curMask } func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { - if c.curBatchSize != key.Dim(2) { - panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2))) + kHeadDim := key.Dim(0) + vHeadDim := value.Dim(0) + numKVHeads := key.Dim(1) + batchSize := key.Dim(2) + + if c.curBatchSize != batchSize { + panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize)) } - if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { - c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity)) - c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity)) + if _, ok := c.ctxs[c.curLayer]; !ok { + c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer) } - ctx.Forward( - key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))), - value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))), - ) + if _, ok := c.keys[c.curLayer]; !ok { + c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity)) + } + + if _, ok := c.values[c.curLayer]; !ok { + if c.config.PermutedV { + c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads) + } else { + c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity)) + } + } + + rowSize := c.keys[c.curLayer].Stride(2) + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize))) + + if c.config.PermutedV { + elemSize := c.values[c.curLayer].Stride(0) + + value = value.Permute(ctx, 1, 2, 0, 3) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads))) + } else { + rowSize := c.values[c.curLayer].Stride(2) + + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, 
rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize))) + } } func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { @@ -379,7 +546,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { } } - kShift, err := ctx.FromIntSlice(offsets, len(offsets)) + kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets)) if err != nil { return err } @@ -389,9 +556,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { continue } - key = key.View(ctx, key.Stride(2)*seqRange.min, - key.Dim(0), key.Stride(1), - key.Dim(1), key.Stride(2), + kHeadDim := key.Dim(0) + numKVHeads := key.Dim(1) + rowSize := key.Stride(2) + + key = key.View(ctx, rowSize*seqRange.min, + kHeadDim, key.Stride(1), + numKVHeads, key.Stride(2), size, ) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index bd7d0ae8b..22d8efb43 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) type testCase struct { @@ -269,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) context := backend.NewContext() defer context.Close() - err := cache.StartForward(context, test.pos, test.seqs) + err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs}) if err != nil { panic(err) } @@ -303,13 +304,17 @@ func (b *testBackend) NewContext() ml.Context { return &testContext{} } +func (b *testBackend) NewContextSize(int) ml.Context { + return &testContext{} +} + func (b *testBackend) SystemInfo() string { return "not implemented" } type testContext struct{} -func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { +func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor { total := 0 if len(shape) > 0 { @@ -322,8 +327,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape} } +func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { + return c.Empty(dtype, shape...) 
+} + func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { - t := c.Zeros(ml.DTypeF32, shape...).(*testTensor) + t := c.Empty(ml.DTypeF32, shape...).(*testTensor) copy(t.data, s) @@ -342,11 +351,15 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { return out, nil } +func (c *testContext) Input() ml.Context { return c } +func (c *testContext) Output() ml.Context { return c } +func (c *testContext) Layer(int) ml.Context { return c } + func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } func (c *testContext) Compute(...ml.Tensor) {} -func (c *testContext) MaxTensors() int { +func (c *testContext) MaxGraphNodes() int { return 10 } @@ -391,7 +404,7 @@ func (t *testTensor) Floats() []float32 { } func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { - out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor) + out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) for i := range out.data { out.data[i] = t.data[i] + t2.(*testTensor).data[i] @@ -468,7 +481,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { context := &testContext{} - view := context.Zeros(t.dtype, s...).(*testTensor) + view := context.Empty(t.dtype, s...).(*testTensor) view.data = t.data[offset : offset+len(view.data)] return view diff --git a/kvcache/encoder.go b/kvcache/encoder.go index b85b1046a..6a9df2abc 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -1,7 +1,10 @@ package kvcache import ( + "fmt" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) // Encoder cache stores K and V tensors that are position independent @@ -11,6 +14,9 @@ import ( // // Not currently safe for multiple sequences type EncoderCache struct { + // config controls mostly backend-specific optimizations + config *ml.CacheConfig + // ** current forward pass ** // the active layer for Get and Put @@ -30,36 +36,59 @@ type EncoderCache struct { encoderPos int32 // ** cache data storage ** - - cacheCtx ml.Context - keys, values []ml.Tensor + backend ml.Backend + ctxs map[int]ml.Context + keys, values map[int]ml.Tensor } func NewEncoderCache() *EncoderCache { - return &EncoderCache{} + return &EncoderCache{ + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), + } } func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { - c.cacheCtx = backend.NewContext() + if c.config == nil { + var config ml.CacheConfig + if cc, ok := backend.(ml.BackendCacheConfig); ok { + config = cc.CacheConfig() + } + c.config = &config + } + + if c.config.CachePadding != 0 && c.config.CachePadding != 1 { + panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) + } + + c.backend = backend +} + +func (c *EncoderCache) SetConfig(config ml.CacheConfig) { + if c.config != nil { + panic("config cannot be changed after being previously set, either by the model or backend") + } + + c.config = &config } func (c *EncoderCache) Close() { - c.cacheCtx.Close() + for _, ctx := range c.ctxs { + ctx.Close() + } } -func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { - // The image is always in the first position - c.curPos = positions[0] +func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error { + // We work with the most recent image + if len(opts.Multimodal) > 0 { + c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index] + } return nil } func 
(c *EncoderCache) SetLayer(layer int) { - if layer >= len(c.keys) { - c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...) - c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...) - } - c.curLayer = layer } @@ -75,9 +104,20 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { c.encoderPos = c.curPos c.encoderCached = true - if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { - c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...) - c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...) + if c.config.PermutedV { + value = value.Permute(ctx, 1, 2, 0, 3) + } + + if _, ok := c.ctxs[c.curLayer]; !ok { + c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer) + } + + if _, ok := c.keys[c.curLayer]; !ok { + c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...) + } + + if _, ok := c.values[c.curLayer]; !ok { + c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...) } ctx.Forward( diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index 2d4c1089a..aaccd1661 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -4,6 +4,7 @@ import ( "math" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) // Wrapper cache is a container for multiple types of caches, @@ -28,20 +29,26 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) } } +func (c *WrapperCache) SetConfig(config ml.CacheConfig) { + for _, cache := range c.caches { + cache.SetConfig(config) + } +} + func (c *WrapperCache) Close() { for _, cache := range c.caches { cache.Close() } } -func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { +func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error { for i, cache := range c.caches { - err := cache.StartForward(ctx, positions, seqs) + err := cache.StartForward(ctx, opts) if err != nil { // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail for j := i - 1; j >= 0; j-- { - for k := range positions { - _ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32) + for k := range opts.Positions { + _ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32) } } return err diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp index c7ff28be1..7a185443a 100644 --- a/llama/llama.cpp/src/llama-vocab.cpp +++ b/llama/llama.cpp/src/llama-vocab.cpp @@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); if (precompiled_charsmap_keyidx != -1) { - size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx); const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); #ifdef IS_BIG_ENDIAN diff --git a/llama/llama.go b/llama/llama.go index 0c4fca430..a026bee24 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -21,18 +21,6 @@ package llama extern bool llamaProgressCallback(float progress, void *user_data); extern void llamaLog(int level, char* text, void* user_data); - -typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER; -COMPILER inline get_compiler() { -#if defined(__clang__) - return COMP_CLANG; -#elif defined(__GNUC__) - return COMP_GCC; 
-#else - return UNKNOWN_COMPILER; -#endif -} - */ import "C" @@ -72,19 +60,6 @@ func BackendInit() { C.llama_backend_init() } -func PrintSystemInfo() string { - var compiler string - switch C.get_compiler() { - case C.COMP_UNKNOWN: - compiler = "cgo(unknown_compiler)" - case C.COMP_GCC: - compiler = "cgo(gcc)" - case C.COMP_CLANG: - compiler = "cgo(clang)" - } - return C.GoString(C.llama_print_system_info()) + compiler -} - func GetModelArch(modelPath string) (string, error) { mp := C.CString(modelPath) defer C.free(unsafe.Pointer(mp)) @@ -270,6 +245,20 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) { return &m, nil } +func LoadVocabFromFile(path string) (*Vocab, error) { + mp := C.CString(path) + defer C.free(unsafe.Pointer(mp)) + v := Vocab{c: C.llama_load_vocab_from_file(mp)} + if v.c == nil { + return nil, fmt.Errorf("unable to load vocab: %s", path) + } + return &v, nil +} + +func FreeVocab(vocab *Vocab) { + C.llama_free_vocab(vocab.c) +} + func FreeModel(model *Model) { C.llama_model_free(model.c) } @@ -318,6 +307,10 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float return nil } +type Vocab struct { + c *C.struct_llama_vocab +} + func (m *Model) Vocab() *C.struct_llama_vocab { return C.llama_model_get_vocab(m.c) } @@ -694,3 +687,53 @@ func SchemaToGrammar(schema []byte) []byte { } return buf[:n] } + +type Sampler struct { + c *C.struct_llama_sampler +} + +func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler { + cGrammar := C.CString(grammar) + cRoot := C.CString("root") + defer C.free(unsafe.Pointer(cGrammar)) + defer C.free(unsafe.Pointer(cRoot)) + + sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)} + + return sampler +} + +func (s *Sampler) Accept(token int32) { + C.llama_sampler_accept(s.c, C.llama_token(token)) +} + +type TokenData struct { + Id int32 + Logit float32 +} + +func (s *Sampler) Apply(tokens []TokenData) { + tds := make([]C.struct_llama_token_data, len(tokens)) + for i, token := range tokens { + tds[i] = C.struct_llama_token_data{ + id: C.int32_t(token.Id), + logit: C.float(token.Logit), + p: C.float(0.0), + } + } + tda := &C.llama_token_data_array{ + data: (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])), + size: C.size_t(len(tokens)), + selected: C.int64_t(-1), + sorted: C.bool(false), + } + + var pinner runtime.Pinner + pinner.Pin(&tds[0]) + defer pinner.Unpin() + + C.llama_sampler_apply(s.c, tda) + for i := range tokens { + tokens[i].Logit = float32(tds[i].logit) + } +} diff --git a/llama/patches/0015-try-catch-backend-load.patch b/llama/patches/0015-try-catch-backend-load.patch deleted file mode 100644 index 9aea61836..000000000 --- a/llama/patches/0015-try-catch-backend-load.patch +++ /dev/null @@ -1,69 +0,0 @@ -From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: Michael Yang -Date: Tue, 11 Feb 2025 14:06:36 -0800 -Subject: [PATCH] try/catch backend load - ---- - ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++----------------- - 1 file changed, 23 insertions(+), 22 deletions(-) - -diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index 98d5e14d..1c19129a 100644 ---- a/ggml/src/ggml-backend-reg.cpp -+++ b/ggml/src/ggml-backend-reg.cpp -@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, - } - fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); - for (const auto & entry : dir_it) { -- if 
(entry.is_regular_file()) { -- std::wstring filename = entry.path().filename().wstring(); -- std::wstring ext = entry.path().extension().wstring(); -- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { -- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; -- if (!handle && !silent) { -- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -- } -- if (handle) { -+ try { -+ if (entry.is_regular_file()) { -+ std::wstring filename = entry.path().filename().wstring(); -+ std::wstring ext = entry.path().extension().wstring(); -+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { -+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; -+ if (!handle) { -+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -+ continue; -+ } -+ - auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); -- if (score_fn) { -- int s = score_fn(); --#ifndef NDEBUG -- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); --#endif -- if (s > best_score) { -- best_score = s; -- best_path = entry.path().wstring(); -- } -- } else { -- if (!silent) { -- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -- } -+ if (!score_fn) { -+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -+ continue; -+ } -+ -+ int s = score_fn(); -+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); -+ if (s > best_score) { -+ best_score = s; -+ best_path = entry.path().wstring(); - } - } - } -+ } catch (const std::exception & e) { -+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what()); - } - } - } diff --git a/llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch b/llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch similarity index 67% rename from llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch rename to llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch index d60066c13..e72d78ac8 100644 --- a/llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch +++ b/llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch @@ -4,11 +4,11 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500 Subject: [PATCH] use std::filesystem::path instead of wstring --- - ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++-------------------- - 1 file changed, 58 insertions(+), 86 deletions(-) + ggml/src/ggml-backend-reg.cpp | 199 +++++++++++++++------------------- + 1 file changed, 88 insertions(+), 111 deletions(-) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index 1c19129a..c854e6bb 100644 +index 98d5e14d..799af5f3 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -66,26 +66,6 @@ @@ -264,47 +264,55 @@ index 1c19129a..c854e6bb 100644 for (const auto & search_path : search_paths) { if (!fs::exists(search_path)) { continue; -@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, +@@ -513,29 +485,26 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, + fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); for (const auto & entry : 
dir_it) { - try { - if (entry.is_regular_file()) { -- std::wstring filename = entry.path().filename().wstring(); -- std::wstring ext = entry.path().extension().wstring(); -+ std::string filename = entry.path().filename().string(); -+ std::string ext = entry.path().extension().string(); - if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { -- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; -+ dl_handle_ptr handle { dl_load_library(entry.path()) }; - if (!handle) { -- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); - continue; - } - - auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); - if (!score_fn) { -- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); - continue; - } - - int s = score_fn(); -- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); -+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); - if (s > best_score) { - best_score = s; -- best_path = entry.path().wstring(); -+ best_path = entry.path(); - } + if (entry.is_regular_file()) { +- std::wstring filename = entry.path().filename().wstring(); +- std::wstring ext = entry.path().extension().wstring(); ++ std::string filename = entry.path().filename().string(); ++ std::string ext = entry.path().extension().string(); + if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { +- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; +- if (!handle && !silent) { +- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); ++ dl_handle_ptr handle { dl_load_library(entry.path()) }; ++ if (!handle) { ++ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); ++ continue; + } +- if (handle) { +- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); +- if (score_fn) { +- int s = score_fn(); +-#ifndef NDEBUG +- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); +-#endif +- if (s > best_score) { +- best_score = s; +- best_path = entry.path().wstring(); +- } +- } else { +- if (!silent) { +- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); +- } +- } ++ ++ auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); ++ if (!score_fn) { ++ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); ++ continue; ++ } ++ ++ int s = score_fn(); ++ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); ++ if (s > best_score) { ++ best_score = s; ++ best_path = entry.path(); } } - } catch (const std::exception & e) { -- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what()); -+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what()); } - } - } -@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, +@@ -545,7 +514,7 @@ static ggml_backend_reg_t 
ggml_backend_load_best(const char * name, bool silent, if (best_score == 0) { // try to load the base backend for (const auto & search_path : search_paths) { @@ -313,3 +321,49 @@ index 1c19129a..c854e6bb 100644 if (fs::exists(path)) { return get_reg().load_backend(path, silent); } +@@ -560,6 +529,14 @@ void ggml_backend_load_all() { + ggml_backend_load_all_from_path(nullptr); + } + ++static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) { ++ try { ++ ggml_backend_load_best(name, silent, user_search_path); ++ } catch (const std::exception & e) { ++ GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what()); ++ } ++} ++ + void ggml_backend_load_all_from_path(const char * dir_path) { + #ifdef NDEBUG + bool silent = true; +@@ -567,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) { + bool silent = false; + #endif + +- ggml_backend_load_best("blas", silent, dir_path); +- ggml_backend_load_best("cann", silent, dir_path); +- ggml_backend_load_best("cuda", silent, dir_path); +- ggml_backend_load_best("hip", silent, dir_path); +- ggml_backend_load_best("kompute", silent, dir_path); +- ggml_backend_load_best("metal", silent, dir_path); +- ggml_backend_load_best("rpc", silent, dir_path); +- ggml_backend_load_best("sycl", silent, dir_path); +- ggml_backend_load_best("vulkan", silent, dir_path); +- ggml_backend_load_best("opencl", silent, dir_path); +- ggml_backend_load_best("musa", silent, dir_path); +- ggml_backend_load_best("cpu", silent, dir_path); ++ ggml_backend_try_load_best("blas", silent, dir_path); ++ ggml_backend_try_load_best("cann", silent, dir_path); ++ ggml_backend_try_load_best("cuda", silent, dir_path); ++ ggml_backend_try_load_best("hip", silent, dir_path); ++ ggml_backend_try_load_best("kompute", silent, dir_path); ++ ggml_backend_try_load_best("metal", silent, dir_path); ++ ggml_backend_try_load_best("rpc", silent, dir_path); ++ ggml_backend_try_load_best("sycl", silent, dir_path); ++ ggml_backend_try_load_best("vulkan", silent, dir_path); ++ ggml_backend_try_load_best("opencl", silent, dir_path); ++ ggml_backend_try_load_best("musa", silent, dir_path); ++ ggml_backend_try_load_best("cpu", silent, dir_path); + // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend + const char * backend_path = std::getenv("GGML_BACKEND_PATH"); + if (backend_path) { diff --git a/llama/patches/0017-remove-amx.patch b/llama/patches/0016-remove-amx.patch similarity index 100% rename from llama/patches/0017-remove-amx.patch rename to llama/patches/0016-remove-amx.patch diff --git a/llama/patches/0018-fix-clip-compiler-error.patch b/llama/patches/0017-fix-clip-compiler-error.patch similarity index 100% rename from llama/patches/0018-fix-clip-compiler-error.patch rename to llama/patches/0017-fix-clip-compiler-error.patch diff --git a/llama/patches/0019-add-phi4-support.patch b/llama/patches/0018-add-phi4-support.patch similarity index 100% rename from llama/patches/0019-add-phi4-support.patch rename to llama/patches/0018-add-phi4-support.patch diff --git a/llama/patches/0019-fix-string-arr-kv-loading.patch b/llama/patches/0019-fix-string-arr-kv-loading.patch new file mode 100644 index 000000000..aa7b4d3cf --- /dev/null +++ b/llama/patches/0019-fix-string-arr-kv-loading.patch @@ -0,0 +1,64 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: jmorganca +Date: Wed, 5 Mar 2025 17:41:07 -0800 +Subject: [PATCH] fix string arr kv loading + +--- + ggml/include/gguf.h | 1 + + 
ggml/src/gguf.cpp | 7 +++++-- + src/llama-vocab.cpp | 2 +- + 3 files changed, 7 insertions(+), 3 deletions(-) + +diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h +index 79ee2020..3efb22f0 100644 +--- a/ggml/include/gguf.h ++++ b/ggml/include/gguf.h +@@ -114,6 +114,7 @@ extern "C" { + // get raw pointer to the first element of the array with the given key_id + // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference) + GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id); ++ GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id); + + // get ith C string from array with given key_id + GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i); +diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp +index ab13669c..f75b923f 100644 +--- a/ggml/src/gguf.cpp ++++ b/ggml/src/gguf.cpp +@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id + + const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); +- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING); + return ctx->kv[key_id].data.data(); + } + ++size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) { ++ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); ++ return ctx->kv[key_id].data.size(); ++} ++ + const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING); +@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) { + const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); +- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING); + return ctx->kv[key_id].data.data(); + } + +diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp +index c7ff28be..7a185443 100644 +--- a/src/llama-vocab.cpp ++++ b/src/llama-vocab.cpp +@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { + + const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); + if (precompiled_charsmap_keyidx != -1) { +- size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); ++ size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx); + const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); + precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); + #ifdef IS_BIG_ENDIAN diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp index 0f137dc8d..b816cedd4 100644 --- a/llama/sampling_ext.cpp +++ b/llama/sampling_ext.cpp @@ -2,6 +2,9 @@ #include "sampling.h" #include "sampling_ext.h" #include "json-schema-to-grammar.h" +#include "llama.h" +#include "llama-model.h" +#include "llama-model-loader.h" struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) { try { @@ -64,3 +67,22 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len) return 0; } } + +struct llama_vocab * llama_load_vocab_from_file(const char * fname) { + llama_vocab * vocab = new llama_vocab(); + try { + 
const auto kv = LLM_KV(LLM_ARCH_UNKNOWN); + std::vector splits = {}; + llama_model_loader ml(std::string(fname), splits, false, false, nullptr); + vocab->load(ml, kv); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); + return nullptr; + } + + return vocab; +} + +void llama_free_vocab(struct llama_vocab * vocab) { + delete vocab; +} diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h index 39f499f19..9be7c100e 100644 --- a/llama/sampling_ext.h +++ b/llama/sampling_ext.h @@ -35,6 +35,9 @@ extern "C" int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len); + struct llama_vocab * llama_load_vocab_from_file(const char * fname); + void llama_free_vocab(struct llama_vocab * vocab); + #ifdef __cplusplus } #endif diff --git a/llm/memory.go b/llm/memory.go index 1da4d2c08..40104eca9 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -115,6 +115,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin // multimodal models require at least 2048 context opts.NumCtx = max(opts.NumCtx, 2048) } + if projectorWeights == 0 && projectorGraph == 0 { + projectorWeights, projectorGraph = f.VisionGraphSize() + } layers := f.Tensors().GroupLayers() // add one layer worth of memory as a buffer diff --git a/llm/server.go b/llm/server.go index fd027a535..a53306fb0 100644 --- a/llm/server.go +++ b/llm/server.go @@ -30,6 +30,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/model" ) type LlamaServer interface { @@ -54,8 +55,15 @@ type llmServer struct { options api.Options numParallel int modelPath string - modelLock sync.Mutex // Temporary until we switch fully to Go server - model *llama.Model // If non-nil, the runner is a new Go server + + // llamaModel is an instance of the cgo llama.cpp model definition + // nil if this server is running the new engine + llamaModel *llama.Model + llamaModelLock sync.Mutex + + // textProcessor handles text encoding/decoding for the model in the Ollama engine + // nil if this server is running the llama.cpp based engine + textProcessor model.TextProcessor estimate MemoryEstimate totalLayers uint64 @@ -89,7 +97,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) { // NewLlamaServer will run a server for the given GPUs // The gpu list must be a single family. 
-func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) { +func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) { systemInfo := discover.GetSystemInfo() systemTotalMemory := systemInfo.System.TotalMemory systemFreeMemory := systemInfo.System.FreeMemory @@ -130,7 +138,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt slog.Info("offload", "", estimate) params := []string{ - "--model", model, + "--model", modelPath, "--ctx-size", strconv.Itoa(opts.NumCtx), "--batch-size", strconv.Itoa(opts.NumBatch), } @@ -153,11 +161,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt } } - if len(projectors) > 0 { - // TODO: applying multiple projectors is not supported by the llama.cpp server yet - params = append(params, "--mmproj", projectors[0]) - } - defaultThreads := systemInfo.GetOptimalThreadCount() if opts.NumThread > 0 { params = append(params, "--threads", strconv.Itoa(opts.NumThread)) @@ -257,6 +260,34 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt } } slog.Debug("compatible gpu libraries", "compatible", compatible) + exe, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("unable to lookup executable path: %w", err) + } + + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval + } + + var llamaModel *llama.Model + var textProcessor model.TextProcessor + if envconfig.NewEngine() { + textProcessor, err = model.NewTextProcessor(modelPath) + if err != nil { + // To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner + slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err) + } + } + if textProcessor == nil { + llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true}) + if err != nil { + return nil, err + } + } + + if len(projectors) > 0 && llamaModel != nil { + params = append(params, "--mmproj", projectors[0]) + } // iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc. // adding each library's respective path to the LD_LIBRARY_PATH, until finally running @@ -275,7 +306,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range } finalParams := []string{"runner"} - if envconfig.NewEngine() { + if textProcessor != nil { + // New engine + // TODO - if we have failure to load scenarios, add logic to retry with the old runner finalParams = append(finalParams, "--ollama-engine") } finalParams = append(finalParams, params...) 
@@ -315,28 +348,20 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt // finally, add the root library path libraryPaths = append(libraryPaths, discover.LibOllamaPath) - exe, err := os.Executable() - if err != nil { - return nil, fmt.Errorf("unable to lookup executable path: %w", err) - } - - if eval, err := filepath.EvalSymlinks(exe); err == nil { - exe = eval - } - - // TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access s := &llmServer{ - port: port, - cmd: exec.Command(exe, finalParams...), - status: NewStatusWriter(os.Stderr), - options: opts, - modelPath: model, - estimate: estimate, - numParallel: numParallel, - sem: semaphore.NewWeighted(int64(numParallel)), - totalLayers: f.KV().BlockCount() + 1, - gpus: gpus, - done: make(chan error, 1), + port: port, + cmd: exec.Command(exe, finalParams...), + status: NewStatusWriter(os.Stderr), + options: opts, + modelPath: modelPath, + llamaModel: llamaModel, + textProcessor: textProcessor, + estimate: estimate, + numParallel: numParallel, + sem: semaphore.NewWeighted(int64(numParallel)), + totalLayers: f.KV().BlockCount() + 1, + gpus: gpus, + done: make(chan error, 1), } s.cmd.Env = os.Environ() @@ -405,6 +430,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt } err := fmt.Errorf("error starting runner: %v %s", err, msg) if len(compatible) == 0 { + if llamaModel != nil { + llama.FreeModel(llamaModel) + } return nil, err } @@ -933,64 +961,25 @@ type TokenizeResponse struct { } func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) { - s.modelLock.Lock() - defer s.modelLock.Unlock() - if s.model != nil { - return s.model.Tokenize(content, false, true) - } + s.llamaModelLock.Lock() + defer s.llamaModelLock.Unlock() - // Make sure the server is ready - status, err := s.getServerStatus(ctx) - if err != nil { - return nil, err - } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable { - return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) + if s.llamaModel != nil { + return s.llamaModel.Tokenize(content, false, true) } - - data, err := json.Marshal(TokenizeRequest{Content: content}) - if err != nil { - return nil, fmt.Errorf("marshaling encode data: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data)) - if err != nil { - return nil, fmt.Errorf("encode request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("do encode request: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode == http.StatusNotFound { - if s.model == nil { - slog.Debug("new runner detected, loading model for cgo tokenization") - m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true}) - if err != nil { - return nil, err - } - s.model = m + if s.textProcessor != nil { + tokens, err := s.textProcessor.Encode(content, false) + if err != nil { + return nil, err } - return s.model.Tokenize(content, false, true) + toks := make([]int, len(tokens)) + for i, t := range tokens { + toks[i] = int(t) + } + return toks, nil } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read encode request: %w", err) - } - - if resp.StatusCode >= 400 { - log.Printf("llm encode error: %s", body) - return nil, fmt.Errorf("%s", body) - } - - var encoded 
TokenizeResponse - if err := json.Unmarshal(body, &encoded); err != nil { - return nil, fmt.Errorf("unmarshal encode response: %w", err) - } - - return encoded.Tokens, nil + // not reached + return nil, fmt.Errorf("no tokenizer configured") } type DetokenizeRequest struct { @@ -1002,80 +991,38 @@ type DetokenizeResponse struct { } func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) { - s.modelLock.Lock() - defer s.modelLock.Unlock() - if s.model != nil { + s.llamaModelLock.Lock() + defer s.llamaModelLock.Unlock() + + if s.llamaModel != nil { var resp string for _, token := range tokens { - resp += s.model.TokenToPiece(token) + resp += s.llamaModel.TokenToPiece(token) } return resp, nil } - // Make sure the server is ready - status, err := s.getServerStatus(ctx) - if err != nil { - return "", err - } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable { - return "", fmt.Errorf("unexpected server status: %s", status.ToString()) - } - - data, err := json.Marshal(DetokenizeRequest{Tokens: tokens}) - if err != nil { - return "", fmt.Errorf("marshaling decode data: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data)) - if err != nil { - return "", fmt.Errorf("decode request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", fmt.Errorf("do decode request: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode == http.StatusNotFound { - if s.model == nil { - slog.Debug("new runner detected, loading model for cgo tokenization") - m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true}) - if err != nil { - return "", err - } - s.model = m + if s.textProcessor != nil { + toks := make([]int32, len(tokens)) + for i, t := range tokens { + toks[i] = int32(t) } - var resp string - for _, token := range tokens { - resp += s.model.TokenToPiece(token) + content, err := s.textProcessor.Decode(toks) + if err != nil { + return "", err } - return resp, nil + return content, nil } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("read decode request: %w", err) - } - - if resp.StatusCode >= 400 { - log.Printf("llm decode error: %s", body) - return "", fmt.Errorf("%s", body) - } - - var decoded DetokenizeResponse - if err := json.Unmarshal(body, &decoded); err != nil { - return "", fmt.Errorf("unmarshal encode response: %w", err) - } - - return decoded.Content, nil + // not reached + return "", fmt.Errorf("no tokenizer configured") } func (s *llmServer) Close() error { - s.modelLock.Lock() - if s.model != nil { - llama.FreeModel(s.model) - s.model = nil + s.llamaModelLock.Lock() + if s.llamaModel != nil { + llama.FreeModel(s.llamaModel) + s.llamaModel = nil } - s.modelLock.Unlock() + s.llamaModelLock.Unlock() if s.cmd != nil { slog.Debug("stopping llama server") diff --git a/ml/backend.go b/ml/backend.go index 07bc75b64..641175f0f 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "os" + "slices" "strconv" "strings" ) @@ -24,7 +25,36 @@ type Backend interface { Config() Config Get(name string) Tensor NewContext() Context - SystemInfo() string + NewContextSize(size int) Context +} + +// BackendCacheConfig should be implemented by backends that need special output +// from the cache to meet specific requirements. 
It is frequently implemented in +// conjunction with ScaledDotProductAttention. +type BackendCacheConfig interface { + CacheConfig() CacheConfig +} + +// CacheConfig controls optimizations (mostly backend-specific) that may transform +// the output the cache to work better with specific kernels. +type CacheConfig struct { + // CachePadding specifies the multiple for the number of tokens of cache history + // that will be returned from cache Get for k, v and mask. The capacity of the + // cache itself will also be increased to a multiple of this size if needed. + CachePadding int + + // PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put + // and return the permuted version via Get. This uses the cache copy operation + // to avoid a Contiguous call on the permuted tensor. + PermutedV bool + + // MaskDType specifies the data type for generating the mask. If unset it will + // default to DTypeF32. + MaskDType DType + + // MaskBatchPadding specifies the multiple for the batch size dimension in the mask. + // Any position that does not correspond to an actual token will be filled with -Inf. + MaskBatchPadding int } // BackendParams controls how the backend loads and executes models @@ -40,6 +70,9 @@ type BackendParams struct { // TensorSplit is the fraction of the model to offload to each GPU TensorSplit []float32 + + // FlashAttention indicates that we should use a fused flash attention kernel + FlashAttention bool } var backends = make(map[string]func(*os.File, BackendParams) (Backend, error)) @@ -61,14 +94,24 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) { } type Context interface { + Empty(dtype DType, shape ...int) Tensor Zeros(dtype DType, shape ...int) Tensor FromFloatSlice(s []float32, shape ...int) (Tensor, error) FromIntSlice(s []int32, shape ...int) (Tensor, error) Forward(...Tensor) Context Compute(...Tensor) - MaxTensors() int + MaxGraphNodes() int Close() + + // Input returns a context appropriate for creating input tensors + Input() Context + + // Output returns a context appropriate for creating output tensors + Output() Context + + // Layer returns a context appropriate for creating intermediate tensors + Layer(int) Context } type Tensor interface { @@ -116,6 +159,10 @@ type Tensor interface { // operation equivalent to following code on a tensor named // query: // +// query = query.Permute(ctx, 0, 2, 1, 3) +// key = key.Permute(ctx, 0, 2, 1, 3) +// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) +// // kq := key.MulmatFullPrec(ctx, query) // // kq = kq.Scale(ctx, scale) @@ -169,8 +216,8 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string { return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) }) - case DTypeF16: - f32 := ctx.Zeros(DTypeF32, t.Shape()...) + case DTypeF16, DTypeQ80, DTypeQ40: + f32 := ctx.Empty(DTypeF32, t.Shape()...) 
f32 = t.Copy(ctx, f32) return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) @@ -195,16 +242,17 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) } shape := t.Shape() + slices.Reverse(shape) var sb strings.Builder var f func([]int, int) f = func(dims []int, stride int) { prefix := strings.Repeat(" ", len(shape)-len(dims)+1) - fmt.Fprint(&sb, "[") - defer func() { fmt.Fprint(&sb, "]") }() + sb.WriteString("[") + defer func() { sb.WriteString("]") }() for i := 0; i < dims[0]; i++ { if i >= items && i < dims[0]-items { - fmt.Fprint(&sb, "..., ") + sb.WriteString("..., ") // skip to next printable element skip := dims[0] - 2*items if len(dims) > 1 { @@ -219,9 +267,14 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix) } } else { - fmt.Fprint(&sb, fn(s[stride+i])) + text := fn(s[stride+i]) + if len(text) > 0 && text[0] != '-' { + sb.WriteString(" ") + } + + sb.WriteString(text) if i < dims[0]-1 { - fmt.Fprint(&sb, ", ") + sb.WriteString(", ") } } } @@ -237,5 +290,7 @@ const ( DTypeOther DType = iota DTypeF32 DTypeF16 + DTypeQ80 + DTypeQ40 DTypeI32 ) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 7f91990c3..74512f337 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1,89 +1,61 @@ package ggml -/* -#cgo CPPFLAGS: -I${SRCDIR}/ggml/include -#include -#include -#include "ggml.h" -#include "ggml-cpu.h" -#include "ggml-backend.h" -static struct ggml_backend_feature * getBackendFeatures(void *fp, ggml_backend_reg_t reg) {return ((ggml_backend_get_features_t)(fp))(reg);} -static struct ggml_backend_feature * getNextBackendFeatures(struct ggml_backend_feature * feature) { return &feature[1];} - -typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER; -COMPILER inline get_compiler() { -#if defined(__clang__) - return COMP_CLANG; -#elif defined(__GNUC__) - return COMP_GCC; -#else - return UNKNOWN_COMPILER; -#endif -} - -*/ +// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include +// #include +// #include +// #include "ggml.h" +// #include "ggml-cpu.h" +// #include "ggml-backend.h" import "C" import ( + "errors" "fmt" "io" "log/slog" + "maps" "os" - "sync" + "slices" + "strconv" + "strings" + "unicode" "unsafe" "github.com/ollama/ollama/format" fs "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/ml" - "golang.org/x/sync/errgroup" - ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" + "golang.org/x/sync/errgroup" ) -type device struct { - d *C.struct_ggml_backend_device -} - -func (d device) LogValue() slog.Value { - var free, total uint64 - C.ggml_backend_dev_memory(d.d, (*C.size_t)(&free), (*C.size_t)(&total)) - - kind := "unknown" - switch C.ggml_backend_dev_type(d.d) { - case C.GGML_BACKEND_DEVICE_TYPE_CPU: - kind = "cpu" - case C.GGML_BACKEND_DEVICE_TYPE_GPU: - kind = "gpu" - case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: - kind = "accel" - } - - return slog.GroupValue( - slog.String("name", C.GoString(C.ggml_backend_dev_name(d.d))), - slog.String("description", C.GoString(C.ggml_backend_dev_description(d.d))), - slog.String("kind", kind), - slog.String("free", format.HumanBytes2(free)), - slog.String("total", format.HumanBytes2(total)), - ) -} - -var devices = sync.OnceValue(func() []device { +func devices() []*C.struct_ggml_backend_device { ggml.OnceLoad() - - s := make([]device, C.ggml_backend_dev_count()) - for i := range s { - s[i] = 
device{C.ggml_backend_dev_get(C.size_t(i))} + ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count()) + for i := range ds { + ds[i] = C.ggml_backend_dev_get(C.size_t(i)) } - return s -}) + return ds +} type Backend struct { - meta *fs.GGML - cpus, gpus []Context - tensors map[string]*Context + meta *fs.GGML + sched *C.struct_ggml_backend_sched + tensors map[string]*C.struct_ggml_tensor - sched *C.struct_ggml_backend_sched + // input is the backend used for inputs + input *C.struct_ggml_backend_buffer_type + + // output is the backend used for outputs + output *C.struct_ggml_backend_buffer_type + + // layers is the backend used for repeating layers + layers map[int]*C.struct_ggml_backend_buffer_type + + flashAttention bool + + // maxGraphNodes is the maximum allowed number of graph nodes in this scheduler + maxGraphNodes int } func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { @@ -102,106 +74,310 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { "num_key_values", len(meta.KV()), ) - var cpus, gpus []Context + type deviceBufferType struct { + d *C.struct_ggml_backend_device + bts []*C.struct_ggml_backend_buffer_type + } + + var cpus, accels, gpus []*C.struct_ggml_backend_device for _, d := range devices() { - switch C.ggml_backend_dev_type(d.d) { + switch C.ggml_backend_dev_type(d) { + case C.GGML_BACKEND_DEVICE_TYPE_CPU: + if len(cpus) == 0 { + // only the first cpu device should be used + cpus = append(cpus, d) + } + case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: + accels = append(accels, d) + case C.GGML_BACKEND_DEVICE_TYPE_GPU: + gpus = append(gpus, d) + } + } + + // create list of buffer types for the cpu + cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)} + for _, d := range append(accels, append(gpus, cpus...)...) { + switch C.ggml_backend_dev_type(d) { case C.GGML_BACKEND_DEVICE_TYPE_CPU, C.GGML_BACKEND_DEVICE_TYPE_ACCEL: - slog.Info("cpu", "device", d) - cpus = append(cpus, Context{ - ctx: C.ggml_init(C.struct_ggml_init_params{ - mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)), - no_alloc: true, - }), - backend: C.ggml_backend_dev_init(d.d, nil), - }) - case C.GGML_BACKEND_DEVICE_TYPE_GPU: - slog.Info("gpu", "device", d) - gpus = append(gpus, Context{ - ctx: C.ggml_init(C.struct_ggml_init_params{ - mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)), - no_alloc: true, - }), - backend: C.ggml_backend_dev_init(d.d, nil), - }) + cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d)) } } - ctxFunc := func(s []Context) (*Context, error) { - for _, e := range s { - return &e, nil - } - - return nil, fmt.Errorf("no devices available") - } - - tensors := make(map[*fs.Tensor]*Context, len(meta.Tensors().Items())) - for _, t := range meta.Tensors().Items() { - c, err := ctxFunc(append(gpus, cpus...)) - if err != nil { - return nil, err - } - - func() { - tt := C.ggml_new_tensor(c.ctx, t.Kind, C.int(len(t.Shape)), (*C.int64_t)(unsafe.Pointer(&t.Shape[0]))) - - cname := C.CString(t.Name) - defer C.free(unsafe.Pointer(cname)) - C.ggml_set_name(tt, cname) - - tensors[t] = c - }() - } - - for _, b := range append(gpus, cpus...) 
{ - C.ggml_backend_alloc_ctx_tensors(b.ctx, b.backend) - } - - sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset)) - - var g errgroup.Group - for t, c := range tensors { - g.Go(func() error { - bts := make([]byte, t.Size()) - n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts) - if err != nil { - return err - } - - if n != int(t.Size()) { - return fmt.Errorf("expected %d bytes, got %d", t.Size(), n) - } - - cname := C.CString(t.Name) - defer C.free(unsafe.Pointer(cname)) - - C.ggml_backend_tensor_set(C.ggml_get_tensor(c.ctx, cname), unsafe.Pointer(&bts[0]), 0, C.size_t(n)) - return nil + // create list of buffer types for each gpu + var gpuDeviceBufferTypes []deviceBufferType + for _, d := range gpus { + bt := C.ggml_backend_dev_buffer_type(d) + gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{ + d: d, + bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...), }) } - if err := g.Wait(); err != nil { + useDefaultSplit := true + for _, s := range params.TensorSplit { + if s != 0 { + useDefaultSplit = false + break + } + } + + // calculate splits + splits := make([]float32, len(gpus)) + if useDefaultSplit { + // default: split on free memory + for i := range splits { + var free, total C.size_t + C.ggml_backend_dev_memory(gpus[i], &free, &total) + splits[i] = float32(free) + } + } else { + splits = params.TensorSplit + } + + var sum float32 + // cumulative sum of all splits + for i := range splits { + sum += splits[i] + splits[i] = sum + } + + // normalize splits + for i := range splits { + splits[i] /= sum + } + + // inputs always use cpu + input := cpuDeviceBufferType + + blocks := int(meta.KV().BlockCount()) + + // define a range of gpu layers. anything outside of this range is assigned to the cpu + gpuRangeStart := max(0, blocks-params.NumGPULayers) + gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1) + assignLayer := func(i int) deviceBufferType { + if i < gpuRangeStart || i >= gpuRangeStop { + return cpuDeviceBufferType + } + + index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f }) + if index < 0 || index >= len(gpuDeviceBufferTypes) { + return cpuDeviceBufferType + } + + return gpuDeviceBufferTypes[index] + } + + // repeating layers are assigned based on their index in reverse order, e.g. 
i / (block_count + 1) + layers := make([]deviceBufferType, blocks) + for i := range layers { + layers[i] = assignLayer(i) + } + + // outputs are assigned iff allowed by splits and configured number of gpu layers + output := assignLayer(blocks) + + maxTensors := len(meta.Tensors().Items()) + maxTensors += 1 + // each layer has at most 2 extra tensors for rope operations + maxTensors += blocks * 2 + + type tensor struct { + source *fs.Tensor + target string + } + + // some tensors are mapped to different names so keep a list + targets := make(map[string][]string) + + // contexts are shared by tensors of the same buffer type + ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context) + createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor { + for _, bt := range bts { + if _, ok := ctxs[bt]; !ok { + ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{ + mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors), + no_alloc: true, + }) + } + + targets[t.source.Name] = append(targets[t.source.Name], t.target) + + name := t.source.Name + if t.target != "" { + name = t.target + } + + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil { + return tt + } + + tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0]))) + C.ggml_set_name(tt, cname) + + slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt))) + //nolint:staticcheck // TODO: check if buffer type supports this tensor + return tt + } + + return nil + } + + contains := func(s string, parts ...string) bool { + split := strings.Split(s, ".") + for _, part := range parts { + if slices.Contains(split, part) { + return true + } + } + + return false + } + + for _, t := range meta.Tensors().Items() { + switch { + case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"): + createTensor(tensor{source: t}, input.bts) + case contains(t.Name, "cls", "output", "output_norm"): + createTensor(tensor{source: t}, output.bts) + case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."): + // TODO: assign vision tensors to the gpu if possible + createTensor(tensor{source: t}, input.bts) + default: + layerIndex := -1 + if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 { + if i, err := strconv.Atoi(fields[0]); err == nil { + layerIndex = i + } + } + + if layerIndex >= 0 { + createTensor(tensor{source: t}, layers[layerIndex].bts) + } else { + // this is a repeating tensor that doesn't explicitly associated with a layer so + // duplicate it for each layer + for i, layer := range layers { + createTensor(tensor{ + source: t, + target: "blk." + strconv.Itoa(i) + "." 
+ t.Name, + }, layer.bts) + } + } + } + } + + // allocate buffers for each context + bbs := make(map[*C.struct_ggml_context]*C.struct_ggml_backend_buffer, len(ctxs)) + for bt, c := range ctxs { + if C.ggml_get_first_tensor(c) == nil { + continue + } + + b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt) + C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) + bbs[c] = b + } + + for bs := range maps.Values(bbs) { + slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs)))) + } + + // map tensor names to tensors for easy lookup later + tensors := make(map[string]*C.struct_ggml_tensor) + for _, c := range ctxs { + for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) { + tensors[C.GoString(C.ggml_get_name(t))] = t + } + } + + // concurrently read in tensor data. uses a section reader which is safe for concurrent reads + sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset)) + var g errgroup.Group + for _, t := range meta.Tensors().Items() { + for _, target := range targets[t.Name] { + g.Go(func() error { + if target == "" { + target = t.Name + } + + tt, ok := tensors[target] + if !ok { + return fmt.Errorf("unassigned tensor: %s", t.Name) + } + + bts := make([]byte, t.Size()) + n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts) + if err != nil { + return err + } + + if n != len(bts) { + return errors.New("short read") + } + + C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size())) + return nil + }) + } + } + + if g.Wait() != nil { return nil, err } - backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus)) - bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus)) - for i, c := range append(gpus, cpus...) { - backends[i] = c.backend - bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend) + // map devices to backend buffer types so new tensors can be assigned to the correct device + deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type) + + // create backends and buffer types used for the compute graph scheduler + var schedBackends []*C.struct_ggml_backend + var schedBufts []*C.struct_ggml_backend_buffer_type + for _, d := range append(gpus, append(accels, cpus...)...) 
{ + b := C.ggml_backend_dev_init(d, nil) + bt := C.ggml_backend_get_default_buffer_type(b) + if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 { + // use the first gpu host buffer type for gpu if possible + if hbt := C.ggml_backend_dev_host_buffer_type(gpus[0]); hbt != nil { + bt = hbt + } + } + + deviceBufferTypes[d] = bt + + schedBackends = append(schedBackends, b) + schedBufts = append(schedBufts, bt) + + slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt))) + + if C.ggml_backend_is_cpu(b) { + // set number of threads for cpu backend + C.ggml_backend_cpu_set_n_threads(b, C.int(params.NumThreads)) + } } + maxGraphNodes := max(8192, len(meta.Tensors().Items())*5) return &Backend{ - meta: meta, - cpus: cpus, - gpus: gpus, + flashAttention: params.FlashAttention, + meta: meta, + tensors: tensors, sched: C.ggml_backend_sched_new( - (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])), - (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])), - C.int(len(backends)), - C.size_t(max(8192, len(meta.Tensors().Items())*5)), + (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])), + (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), + C.int(len(schedBackends)), + C.size_t(maxGraphNodes), true, ), + input: deviceBufferTypes[input.d], + output: deviceBufferTypes[output.d], + layers: func() map[int]*C.struct_ggml_backend_buffer_type { + m := make(map[int]*C.struct_ggml_backend_buffer_type) + for i, layer := range layers { + m[i] = deviceBufferTypes[layer.d] + } + return m + }(), + maxGraphNodes: maxGraphNodes, }, nil } @@ -214,51 +390,95 @@ func (b *Backend) Config() ml.Config { } func (b *Backend) Get(name string) ml.Tensor { - cname := C.CString(name) - defer C.free(unsafe.Pointer(cname)) - - for _, c := range append(b.gpus, b.cpus...) { - if t := C.ggml_get_tensor(c.ctx, cname); t != nil { - return &Tensor{t: t} - } + if t, ok := b.tensors[name]; ok { + return &Tensor{b: b, t: t} } return nil } func (b *Backend) NewContext() ml.Context { - nodes := max(8192, len(b.meta.Tensors().Items())*5) - c := C.ggml_init(C.struct_ggml_init_params{ - mem_buffer: nil, - mem_size: C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false), - no_alloc: true, - }) + return b.NewContextSize(b.maxGraphNodes) +} - backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus)) - for i, c := range append(b.gpus, b.cpus...) 
{ - backends[i] = c.backend +func (b *Backend) NewContextSize(n int) ml.Context { + if n > b.maxGraphNodes { + panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes)) } return &Context{ - b: b, - ctx: c, - backend: backends[0], - nodes: nodes, + b: b, + maxGraphNodes: n, + ctx: C.ggml_init(C.struct_ggml_init_params{ + mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false), + no_alloc: true, + }), + } +} + +func (b *Backend) CacheConfig() ml.CacheConfig { + if b.flashAttention { + return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD} + } else { + return ml.CacheConfig{CachePadding: 32, PermutedV: true} } } type Context struct { - b *Backend - ctx *C.struct_ggml_context - backend *C.struct_ggml_backend + b *Backend + ctx *C.struct_ggml_context graph *C.struct_ggml_cgraph - nodes int + + // buft is the buffer type used for new tensors + buft *C.struct_ggml_backend_buffer_type + + // maxGraphNodes is the maximum allowed number of graph nodes in this context + maxGraphNodes int +} + +func (c Context) Input() ml.Context { + if c.b.input != nil { + return &Context{ + b: c.b, + ctx: c.ctx, + buft: c.b.input, + maxGraphNodes: c.maxGraphNodes, + } + } + + return &c +} + +func (c Context) Output() ml.Context { + if c.b.output != nil { + return &Context{ + b: c.b, + ctx: c.ctx, + buft: c.b.output, + maxGraphNodes: c.maxGraphNodes, + } + } + + return &c +} + +func (c Context) Layer(i int) ml.Context { + if buft, ok := c.b.layers[i]; ok { + return &Context{ + b: c.b, + ctx: c.ctx, + buft: buft, + maxGraphNodes: c.maxGraphNodes, + } + } + + return &c } func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { if c.graph == nil { - c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false) + c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false) } for _, tensor := range tensors { @@ -268,7 +488,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { return c } -func (c *Context) Compute(tensors ...ml.Tensor) { +func (c Context) Compute(tensors ...ml.Tensor) { C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph) C.ggml_backend_sched_reset(c.b.sched) @@ -287,21 +507,48 @@ func (c *Context) Compute(tensors ...ml.Tensor) { } } -func (c *Context) MaxTensors() int { - return c.nodes +func (c Context) MaxGraphNodes() int { + return c.maxGraphNodes } func shapeToGGML(shape []int) *C.int64_t { sh := make([]C.int64_t, len(shape)) for i, s := range shape { - sh[i] = (C.int64_t)(s) + sh[i] = C.int64_t(s) } return &sh[0] } -func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { - if len(shape) < 1 || len(shape) > 4 { +func pad(length, pad C.size_t) C.size_t { + return ((length + pad - 1) / pad) * pad +} + +func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { + if c.buft == nil { + panic("set Input, Output, or Layer before creating tensors") + } + + var cdtype uint32 + switch dtype { + case ml.DTypeF32: + cdtype = C.GGML_TYPE_F32 + case ml.DTypeF16: + cdtype = C.GGML_TYPE_F16 + case ml.DTypeQ80: + cdtype = C.GGML_TYPE_Q8_0 + case ml.DTypeQ40: + cdtype = C.GGML_TYPE_Q4_0 + case ml.DTypeI32: + cdtype = C.GGML_TYPE_I32 + default: + panic("unsupported dtype") + } + + if len(shape) < 1 || shape[0] == 0 { + var shape C.int64_t = 0 + return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)} + } else if len(shape) > 4 { panic("unsupported number of dimensions") } @@ -311,31 +558,28 @@ func (c 
Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { } } - var t *C.struct_ggml_tensor - switch dtype { - case ml.DTypeF32: - t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape)) - case ml.DTypeF16: - t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape)) - case ml.DTypeI32: - t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape)) - default: - panic("unsupported dtype") - } - - b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t)) + t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape)) + size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft)) + b := C.ggml_backend_buft_alloc_buffer(c.buft, size) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) - C.ggml_set_zero(t) - return &Tensor{t: t} + return &Tensor{b: c.b, t: t} } -func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) { +func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { + return c.newTensor(dtype, shape) +} + +func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { + t := c.newTensor(dtype, shape) + C.ggml_set_zero(t.(*Tensor).t) + return t +} + +func checkShape[S ~[]E, E any](s S, shape ...int) error { n := len(s) if n == 0 { - var shape C.int64_t = 0 - t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape) - return &Tensor{t: t}, nil + return nil } for _, v := range shape { @@ -343,22 +587,36 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u } if n != 1 { - return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s)) + return fmt.Errorf("invalid shape: %v", shape) } - t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape)) - b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t)) - C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) - C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t)) - return &Tensor{t: t}, nil + return nil } func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { - return fromSlice(c, s, shape, C.GGML_TYPE_F32) + if err := checkShape(s, shape...); err != nil { + return nil, err + } + + t := c.newTensor(ml.DTypeF32, shape) + if len(s) > 0 { + C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) + } + + return t, nil } func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { - return fromSlice(c, s, shape, C.GGML_TYPE_I32) + if err := checkShape(s, shape...); err != nil { + return nil, err + } + + t := c.newTensor(ml.DTypeI32, shape) + if len(s) > 0 { + C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) + } + + return t, nil } func (c *Context) Close() { @@ -368,6 +626,7 @@ func (c *Context) Close() { } type Tensor struct { + b *Backend t *C.struct_ggml_tensor sync func() } @@ -425,6 +684,10 @@ func (t *Tensor) DType() ml.DType { return ml.DTypeF32 case C.GGML_TYPE_F16: return ml.DTypeF16 + case C.GGML_TYPE_Q8_0: + return ml.DTypeQ80 + case C.GGML_TYPE_Q4_0: + return ml.DTypeQ40 case C.GGML_TYPE_I32: return ml.DTypeI32 default: @@ -434,6 +697,7 @@ func (t *Tensor) DType() ml.DType { func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } @@ -448,24 +712,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) 
ml.Tensor { func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)), } } func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_cont(ctx.(*Context).ctx, t.t), } } func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } @@ -475,12 +743,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor { C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32) return &Tensor{ + b: t.b, t: mul, } } func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { - tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) + tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) if b != nil { tt = tt.Add(ctx, b) } @@ -489,7 +758,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso } func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor { - return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) + return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) } func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { @@ -498,6 +767,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { } return &Tensor{ + b: t.b, t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), } } @@ -508,18 +778,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor { } return &Tensor{ + b: t.b, t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), } } func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } @@ -528,18 +801,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { switch len(shape) { case 1: return &Tensor{ + b: t.b, t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])), } case 2: return &Tensor{ + b: t.b, t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])), } case 3: return &Tensor{ + b: t.b, t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])), } case 4: return &Tensor{ + b: t.b, t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])), } default: @@ -549,18 +826,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)), } } func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_soft_max(ctx.(*Context).ctx, t.t), } } func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t), } } @@ -571,6 +851,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor { } return &Tensor{ + 
b: t.b, t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), } } @@ -579,10 +860,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { switch len(shape) { case 1: return &Tensor{ + b: t.b, t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)), } case 3: return &Tensor{ + b: t.b, t: C.ggml_view_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[2]), C.size_t(shape[1]), @@ -590,6 +873,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } case 5: return &Tensor{ + b: t.b, t: C.ggml_view_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.size_t(shape[1]), C.size_t(shape[3]), @@ -597,6 +881,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } case 7: return &Tensor{ + b: t.b, t: C.ggml_view_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]), C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]), @@ -613,7 +898,7 @@ const ( func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor { if ropeFactors == nil { - ropeFactors = &Tensor{} + ropeFactors = &Tensor{b: t.b} } dequant := t.t @@ -622,6 +907,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi } return &Tensor{ + b: t.b, t: C.ggml_rope_ext( ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, C.int(ropeDim), @@ -639,18 +925,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t), } } func (t *Tensor) SILU(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t), } } func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)), } } @@ -661,42 +950,23 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T kqMask = mask.(*Tensor).t } - kq := key.MulmatFullPrec(ctx, t) - kq = &Tensor{ - t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), - } + query := t.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) - kqv := value.Mulmat(ctx, kq) - return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) -} + if t.b.flashAttention { + value = value.Permute(ctx, 0, 2, 1, 3) -func (b *Backend) SystemInfo() string { - var compiler string - switch C.get_compiler() { - case C.COMP_UNKNOWN: - compiler = "cgo(unknown_compiler)" - case C.COMP_GCC: - compiler = "cgo(gcc)" - case C.COMP_CLANG: - compiler = "cgo(clang)" - } - - var s string - for i := range C.ggml_backend_reg_count() { - reg := C.ggml_backend_reg_get(i) - fName := C.CString("ggml_backend_get_features") - defer C.free(unsafe.Pointer(fName)) - get_features_fn := C.ggml_backend_reg_get_proc_address(reg, fName) - if get_features_fn != nil { - s += C.GoString(C.ggml_backend_reg_name(reg)) - s += " : " - for features := C.getBackendFeatures(get_features_fn, reg); features.name != nil; features = C.getNextBackendFeatures(features) { - s += C.GoString(features.name) - s += " = " - s += C.GoString(features.value) - s += " | " - } + kqv := 
C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0) + C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32) + return &Tensor{b: t.b, t: kqv} + } else { + kq := key.MulmatFullPrec(ctx, query) + kq = &Tensor{ + b: t.b, + t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), } + + kqv := value.Mulmat(ctx, kq) + return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) } - return s + compiler } diff --git a/ml/backend/ggml/ggml/include/gguf.h b/ml/backend/ggml/ggml/include/gguf.h index 79ee20206..3efb22f01 100644 --- a/ml/backend/ggml/ggml/include/gguf.h +++ b/ml/backend/ggml/ggml/include/gguf.h @@ -114,6 +114,7 @@ extern "C" { // get raw pointer to the first element of the array with the given key_id // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference) GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id); + GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id); // get ith C string from array with given key_id GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i); diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp index c854e6bb2..799af5f3a 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp @@ -484,33 +484,29 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, } fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); for (const auto & entry : dir_it) { - try { - if (entry.is_regular_file()) { - std::string filename = entry.path().filename().string(); - std::string ext = entry.path().extension().string(); - if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { - dl_handle_ptr handle { dl_load_library(entry.path()) }; - if (!handle) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); - continue; - } + if (entry.is_regular_file()) { + std::string filename = entry.path().filename().string(); + std::string ext = entry.path().extension().string(); + if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { + dl_handle_ptr handle { dl_load_library(entry.path()) }; + if (!handle) { + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); + continue; + } - auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); - if (!score_fn) { - GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); - continue; - } + auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); + if (!score_fn) { + GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); + continue; + } - int s = score_fn(); - GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); - if (s > best_score) { - best_score = s; - best_path = entry.path(); - } + int s = score_fn(); + GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); + if (s > best_score) { + best_score = s; + best_path = entry.path(); } } - } catch (const std::exception & e) { - GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what()); } } } 
@@ -533,6 +529,14 @@ void ggml_backend_load_all() { ggml_backend_load_all_from_path(nullptr); } +static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) { + try { + ggml_backend_load_best(name, silent, user_search_path); + } catch (const std::exception & e) { + GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what()); + } +} + void ggml_backend_load_all_from_path(const char * dir_path) { #ifdef NDEBUG bool silent = true; @@ -540,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) { bool silent = false; #endif - ggml_backend_load_best("blas", silent, dir_path); - ggml_backend_load_best("cann", silent, dir_path); - ggml_backend_load_best("cuda", silent, dir_path); - ggml_backend_load_best("hip", silent, dir_path); - ggml_backend_load_best("kompute", silent, dir_path); - ggml_backend_load_best("metal", silent, dir_path); - ggml_backend_load_best("rpc", silent, dir_path); - ggml_backend_load_best("sycl", silent, dir_path); - ggml_backend_load_best("vulkan", silent, dir_path); - ggml_backend_load_best("opencl", silent, dir_path); - ggml_backend_load_best("musa", silent, dir_path); - ggml_backend_load_best("cpu", silent, dir_path); + ggml_backend_try_load_best("blas", silent, dir_path); + ggml_backend_try_load_best("cann", silent, dir_path); + ggml_backend_try_load_best("cuda", silent, dir_path); + ggml_backend_try_load_best("hip", silent, dir_path); + ggml_backend_try_load_best("kompute", silent, dir_path); + ggml_backend_try_load_best("metal", silent, dir_path); + ggml_backend_try_load_best("rpc", silent, dir_path); + ggml_backend_try_load_best("sycl", silent, dir_path); + ggml_backend_try_load_best("vulkan", silent, dir_path); + ggml_backend_try_load_best("opencl", silent, dir_path); + ggml_backend_try_load_best("musa", silent, dir_path); + ggml_backend_try_load_best("cpu", silent, dir_path); // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend const char * backend_path = std::getenv("GGML_BACKEND_PATH"); if (backend_path) { diff --git a/ml/backend/ggml/ggml/src/ggml.go b/ml/backend/ggml/ggml/src/ggml.go index 85c693eba..afc1e1edd 100644 --- a/ml/backend/ggml/ggml/src/ggml.go +++ b/ml/backend/ggml/ggml/src/ggml.go @@ -7,6 +7,20 @@ package ggml // #include // #include "ggml-backend.h" // extern void sink(int level, char *text, void *user_data); +// static struct ggml_backend_feature * first_feature(ggml_backend_get_features_t fp, ggml_backend_reg_t reg) { return fp(reg); } +// static struct ggml_backend_feature * next_feature(struct ggml_backend_feature * feature) { return &feature[1]; } +/* +typedef enum { COMPILER_CLANG, COMPILER_GNUC, COMPILER_UNKNOWN } COMPILER; +static COMPILER compiler_name(void) { +#if defined(__clang__) + return COMPILER_CLANG; +#elif defined(__GNUC__) + return COMPILER_GNUC; +#else + return COMPILER_UNKNOWN; +#endif +} +*/ import "C" import ( @@ -16,6 +30,7 @@ import ( "os" "path/filepath" "runtime" + "strconv" "strings" "sync" "unsafe" @@ -90,4 +105,43 @@ var OnceLoad = sync.OnceFunc(func() { visited[abspath] = struct{}{} } } + + slog.Info("system", "", system{}) }) + +type system struct{} + +func (system) LogValue() slog.Value { + var attrs []slog.Attr + names := make(map[string]int) + for i := range C.ggml_backend_dev_count() { + r := C.ggml_backend_dev_backend_reg(C.ggml_backend_dev_get(i)) + + func() { + fName := C.CString("ggml_backend_get_features") + defer C.free(unsafe.Pointer(fName)) + + if fn := C.ggml_backend_reg_get_proc_address(r, fName); fn 
!= nil { + var features []any + for f := C.first_feature(C.ggml_backend_get_features_t(fn), r); f.name != nil; f = C.next_feature(f) { + features = append(features, C.GoString(f.name), C.GoString(f.value)) + } + + name := C.GoString(C.ggml_backend_reg_name(r)) + attrs = append(attrs, slog.Group(name+"."+strconv.Itoa(names[name]), features...)) + names[name] += 1 + } + }() + } + + switch C.compiler_name() { + case C.COMPILER_CLANG: + attrs = append(attrs, slog.String("compiler", "cgo(clang)")) + case C.COMPILER_GNUC: + attrs = append(attrs, slog.String("compiler", "cgo(gcc)")) + default: + attrs = append(attrs, slog.String("compiler", "cgo(unknown)")) + } + + return slog.GroupValue(attrs...) +} diff --git a/ml/backend/ggml/ggml/src/gguf.cpp b/ml/backend/ggml/ggml/src/gguf.cpp index ab13669c5..f75b923fd 100644 --- a/ml/backend/ggml/ggml/src/gguf.cpp +++ b/ml/backend/ggml/ggml/src/gguf.cpp @@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) { GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING); return ctx->kv[key_id].data.data(); } +size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].data.size(); +} + const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) { GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING); @@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) { const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) { GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); - GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING); return ctx->kv[key_id].data.data(); } diff --git a/ml/nn/attention.go b/ml/nn/attention.go index 4f0c9fa14..a3f43a1ea 100644 --- a/ml/nn/attention.go +++ b/ml/nn/attention.go @@ -3,6 +3,7 @@ package nn import ( "fmt" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" ) @@ -11,40 +12,50 @@ import ( // // Parameters: // - ctx: Context for tensor operations -// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads] -// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads] -// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads] -// - mask: Optional attention mask that is added to the attention score. 
If -// provided, should broadcast to [seq_len_k, seq_len_q, heads] +// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q] +// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only +// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only // - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension +// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value // // Returns: // // Attention output with shape [d_v, heads, seq_len_q] -func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor { - if query.Dim(0) != key.Dim(0) { - panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) +func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { + if key != nil && value != nil { + if query.Dim(0) != key.Dim(0) { + panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) + } + + if key.Dim(1) != value.Dim(1) { + panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))) + } + + if key.Dim(2) != value.Dim(2) { + panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) + } + + if cache != nil { + cache.Put(ctx, key, value) + } + } else if cache == nil { + panic("key & value tensors must be provided if cache is nil") } - if mask != nil && query.Dim(1) != mask.Dim(1) { - panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1))) + var mask ml.Tensor + if cache != nil { + key, value, mask = cache.Get(ctx) } - if key.Dim(1) != value.Dim(0) { - panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0))) - } - - if mask != nil && key.Dim(1) != mask.Dim(0) { - panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0))) - } - - if key.Dim(2) != value.Dim(2) { - panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) - } - - if sdpa, ok := query.(ml.ScaledDotProductAttention); ok { + // Only use the fast SDPA implementation if we have a cache, since that's what + // will do any expected backend-specific transformations for us + if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil { return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale) } else { + query = query.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + kq := key.MulmatFullPrec(ctx, query) kq = kq.Scale(ctx, scale) diff --git a/model/input/input.go b/model/input/input.go new file mode 100644 index 000000000..0cb3f3f41 --- /dev/null +++ b/model/input/input.go @@ -0,0 +1,37 @@ +package input + +// Input represents one token in the input stream +type Input struct { + // Token is a single element of text. + Token int32 + + // Multimodal is opaque data representing a non-text + // element such as an image (or part of one if the image + // can be processed in pieces). It may be either together + // with Token or on its own. 
+ Multimodal any + + // MultimodalHash is a unique representation of the data + // stored in Multimodal, used for caching and comparing + // equality. + MultimodalHash uint64 +} + +// MultimodalIndex is a multimodal element (such as an image) +// together with an index into the slice of Inputs with the +// corresponding token. Note that the index is not the same +// as the position - to find that use the index with the +// Positions slice. +type MultimodalIndex struct { + Index int + Multimodal any +} + +// Options contains the inputs for a model forward pass +type Options struct { + Inputs []int32 + Multimodal []MultimodalIndex + Positions []int32 + Sequences []int + Outputs []int32 +} diff --git a/model/model.go b/model/model.go index 16020b354..89b6c803b 100644 --- a/model/model.go +++ b/model/model.go @@ -3,7 +3,6 @@ package model import ( "errors" "fmt" - "image" _ "image/jpeg" _ "image/png" "log/slog" @@ -16,23 +15,50 @@ import ( _ "golang.org/x/image/tiff" _ "golang.org/x/image/webp" + fs "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" + "github.com/ollama/ollama/model/input" ) -// Options contains the inputs for a model forward pass -type Options struct { - Inputs []int32 - Positions []int32 - Sequences []int - Outputs []int32 +// Model implements a specific model architecture, defining the forward pass and any model-specific configuration +type Model interface { + Forward(ml.Context, input.Options) (ml.Tensor, error) - Images []image.Image + Backend() ml.Backend + Config() config } -type config struct { - Cache kvcache.Cache +// MultimodalProcessor must be implemented by multimodal models. +type MultimodalProcessor interface { + // EncodeMultimodal processes a single input (such as an image) and + // generates an output (typically an embedding) that can be used by the model. + // + // The return value is most typically an ml.Tensor, however, different + // types are possible, such as an object containing a tensor plus + // additional metadata, a slice of tensors or even just the original input. + // + // The result may be cached by the runner. + EncodeMultimodal(ml.Context, []byte) (any, error) + + // PostTokenize is called after tokenization to allow the model to edit the + // input stream to correctly arrange multimodal elements. + // + // The input is a slice of tokens with the results of EncodeMultimodal interleaved + // in the order that the user provided them. Each element of the slice will be + // either a single token or a single multimodal object. + // + // The model must ensure that inputs are stored according to how they will be + // processed and stored in the cache. For example, Llava-style models should insert + // placeholder tokens equal to the feature size of the corresponding image with + // the image itself attached to and split across these tokens. When Forward is called + // a partial subset of these tokens may be submitted according to the batch size. + // + // This function is also responsible for updating MultimodalHash for any Multimodal + // that is modified to ensure that there is a unique hash value that accurately + // represents the contents.
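As a hypothetical sketch of the PostTokenize contract described above (the actual interface signature follows next), a Llava-style model might expand each image into placeholder slots that the image embedding is attached to. The placeholder token ID and feature count below are invented for illustration, and the ml.Context argument is dropped for brevity; this is not the mllama implementation shown later in this diff:

// postTokenizeSketch shows one way a model could satisfy the contract above.
// imagePlaceholder and featuresPerImage are illustrative values only.
func postTokenizeSketch(inputs []input.Input) []input.Input {
	const imagePlaceholder = int32(-1)
	const featuresPerImage = 576

	var out []input.Input
	for _, inp := range inputs {
		if inp.Multimodal == nil {
			out = append(out, inp) // plain text token, passed through
			continue
		}

		// spread the image across featuresPerImage placeholder slots so that
		// caching and batching see one entry per image feature; the payload
		// and its hash stay attached to the first slot
		out = append(out, input.Input{
			Token:          imagePlaceholder,
			Multimodal:     inp.Multimodal,
			MultimodalHash: inp.MultimodalHash,
		})
		for i := 1; i < featuresPerImage; i++ {
			out = append(out, input.Input{Token: imagePlaceholder})
		}
	}
	return out
}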
+ PostTokenize(ml.Context, []input.Input) ([]input.Input, error) } // Base implements the common fields and methods for all models @@ -41,6 +67,10 @@ type Base struct { config } +type config struct { + Cache kvcache.Cache +} + // Backend returns the underlying backend that will run the model func (m *Base) Backend() ml.Backend { return m.b @@ -50,14 +80,6 @@ func (m *Base) Config() config { return m.config } -// Model implements a specific model architecture, defining the forward pass and any model-specific configuration -type Model interface { - Forward(ml.Context, Options) (ml.Tensor, error) - - Backend() ml.Backend - Config() config -} - var models = make(map[string]func(ml.Config) (Model, error)) // Register registers a model constructor for the given architecture @@ -100,6 +122,36 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { return m, nil } +func NewTextProcessor(s string) (TextProcessor, error) { + r, err := os.Open(s) + if err != nil { + return nil, err + } + defer r.Close() + meta, _, err := fs.Decode(r, -1) + if err != nil { + return nil, err + } + return getTextProcessor(meta.KV()) +} + +func getTextProcessor(kv fs.KV) (TextProcessor, error) { + arch := kv.Architecture() + f, ok := models[arch] + if !ok { + return nil, fmt.Errorf("unsupported model architecture %q", arch) + } + m, err := f(kv) + if err != nil { + return nil, err + } + tp, ok := m.(TextProcessor) + if !ok { + return nil, fmt.Errorf("%v is not a TextProcessor", m) + } + return tp, nil +} + func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { t := v.Type() @@ -226,7 +278,7 @@ func canNil(t reflect.Type) bool { t.Kind() == reflect.Slice } -func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) { +func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) { if len(opts.Positions) != len(opts.Sequences) { return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences)) } @@ -237,7 +289,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) { cache := m.Config().Cache if cache != nil { - err := cache.StartForward(ctx, opts.Positions, opts.Sequences) + err := cache.StartForward(ctx, opts) if err != nil { return nil, err } diff --git a/model/model_test.go b/model/model_test.go index 02b8aa3c2..354dd1d8b 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -3,12 +3,15 @@ package model import ( "reflect" "slices" + "strings" "testing" "github.com/google/go-cmp/cmp" + fs "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model/input" ) func TestParseTags(t *testing.T) { @@ -134,3 +137,40 @@ func TestPopulateFieldsAlternateName(t *testing.T) { t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) } } + +func TestGetTextProcessor(t *testing.T) { + tp, err := getTextProcessor(fs.KV{}) + if err == nil { + t.Error("expected error") + } else if !strings.Contains(err.Error(), "unsupported model architecture") { + t.Errorf("unexpected error: %v", err) + } else if tp != nil { + t.Error("expected nil tp") + } + + models["dummy"] = func(ml.Config) (Model, error) { + return notTextProcessorModel{}, nil + } + tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"}) + if err == nil { + t.Error("expected error") + } else if !strings.Contains(err.Error(), "not a TextProcessor") { + t.Errorf("unexpected error: %v", err) + } 
else if tp != nil { + t.Error("expected nil tp") + } +} + +type notTextProcessorModel struct{} + +func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) { + panic("unimplemented") +} + +func (notTextProcessorModel) Backend() ml.Backend { + panic("unimplemented") +} + +func (notTextProcessorModel) Config() config { + panic("unimplemented") +} diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 6106af867..1f27f522d 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -1,16 +1,18 @@ package llama import ( + "fmt" "math" + "strings" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) type Options struct { - RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` hiddenSize, numHeads, numKVHeads int eps, ropeBase, ropeScale float32 ropeDim uint32 @@ -29,6 +31,10 @@ type Model struct { } func New(c ml.Config) (model.Model, error) { + if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { + return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) + } + m := Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), @@ -60,10 +66,11 @@ func New(c ml.Config) (model.Model, error) { } type SelfAttention struct { - Query *nn.Linear `gguf:"attn_q"` - Key *nn.Linear `gguf:"attn_k"` - Value *nn.Linear `gguf:"attn_v"` - Output *nn.Linear `gguf:"attn_output"` + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` + RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` } func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { @@ -72,31 +79,24 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) + q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) + k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - cache.Put(ctx, k, v) - k, v, mask := cache.Get(ctx) - - q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor) + kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, kqv) } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil + return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil } type MLP struct { @@ -138,18 
+138,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten return hiddenState.Add(ctx, residual) } -func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { - inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs)) +func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { + inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) if err != nil { return nil, err } - positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions)) + positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) if err != nil { return nil, err } - outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 9b35a2628..31ba15dfd 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -1,10 +1,18 @@ package mllama import ( + "bytes" + "encoding/binary" + "fmt" + "hash/fnv" + "image" + "slices" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) type Model struct { @@ -25,6 +33,10 @@ const ( ) func New(c ml.Config) (model.Model, error) { + // Verify unified config + if c.Uint("vision.block_count") == 0 { + return nil, fmt.Errorf("non-unified vision model not supported") + } m := Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), @@ -43,59 +55,99 @@ func New(c ml.Config) (model.Model, error) { TextModel: newTextModel(c), } - m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift)) + encoderCache := kvcache.NewEncoderCache() + encoderCache.SetConfig(ml.CacheConfig{}) + m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift)) return &m, nil } -func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { + image, _, err := image.Decode(bytes.NewReader(multimodalData)) + if err != nil { + return nil, err + } + + f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image) + if err != nil { + return nil, err + } + + pixelValues, err := ctx.Input().FromFloatSlice(f32s, + m.ImageProcessor.imageSize, + m.ImageProcessor.imageSize, + m.ImageProcessor.numChannels, + m.ImageProcessor.maxNumTiles, + ) + if err != nil { + return nil, err + } + + aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1) + if err != nil { + return nil, err + } + + positions := make([]int32, 1601) + for i := range positions { + positions[i] = int32(i) + } + + positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions)) + if err != nil { + return nil, err + } + + crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio) + return m.Projector.Forward(ctx, crossAttentionStates), nil +} + +func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { + var images []input.Input + fnvHash := fnv.New64a() + + for i := range inputs { + if inputs[i].Multimodal == nil { + if len(images) > 0 { + inputs[i].Multimodal = images[0].Multimodal + 
inputs[i].MultimodalHash = images[0].MultimodalHash + for j := 1; j < len(images); j++ { + inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3) + fnvHash.Reset() + binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash) + binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash) + inputs[i].MultimodalHash = fnvHash.Sum64() + } + images = nil + } + } else { + images = append(images, inputs[i]) + inputs[i].Token = -1 + } + } + + inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 }) + + return inputs, nil +} + +func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { var crossAttentionStates ml.Tensor - if opts.Images != nil { - f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0]) - if err != nil { - return nil, err - } - - pixelValues, err := ctx.FromFloatSlice(f32s, - m.ImageProcessor.imageSize, - m.ImageProcessor.imageSize, - m.ImageProcessor.numChannels, - m.ImageProcessor.maxNumTiles, - ) - if err != nil { - return nil, err - } - - aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1) - if err != nil { - return nil, err - } - - positions := make([]int32, 1601) - for i := range positions { - positions[i] = int32(i) - } - - positionIDs, err := ctx.FromIntSlice(positions, len(positions)) - if err != nil { - return nil, err - } - - crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio) - crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates) + if len(opts.Multimodal) > 0 { + crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor) } - inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs)) + inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) if err != nil { return nil, err } - positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions)) + positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) if err != nil { return nil, err } - outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 003bf9cbf..373589f9e 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -10,10 +10,11 @@ import ( ) type TextSelfAttention struct { - Query *nn.Linear `gguf:"attn_q"` - Key *nn.Linear `gguf:"attn_k"` - Value *nn.Linear `gguf:"attn_v"` - Output *nn.Linear `gguf:"attn_output"` + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` + RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` } func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { @@ -22,32 +23,28 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) + query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = key.RoPE(ctx, positions, 
opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) + key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - cache.Put(ctx, key, value) - key, value, mask := cache.Get(ctx) - - query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - attention := nn.Attention(ctx, query, key, value, mask, scaleFactor) + attention := nn.Attention(ctx, query, key, value, scaleFactor, cache) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, attention) } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - // This will only get called for layers in the cache, which are just the self attention layers - return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil + if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { + return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil + } + + return key, nil } type TextMLP struct { @@ -107,7 +104,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = ca.QueryNorm.Forward(ctx, query, opts.eps) - var key, value, mask ml.Tensor + var key, value ml.Tensor if crossAttentionStates != nil { numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) @@ -119,16 +116,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) cache.Put(ctx, key, value) - } else { - key, value, mask = cache.Get(ctx) } - query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + key, value, _ = cache.Get(ctx) scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - attention := nn.Attention(ctx, query, key, value, mask, scaleFactor) + + query = query.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + + kq := key.MulmatFullPrec(ctx, query) + + kq = kq.Scale(ctx, scaleFactor) + kq = kq.Softmax(ctx) + + kqv := value.Mulmat(ctx, kq) + attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return ca.Output.Forward(ctx, attention) @@ -191,8 +195,6 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, } type TextModelOptions struct { - RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` - hiddenSize, numHeads, numKVHeads int eps, ropeBase, ropeScale float32 ropeDim uint32 diff --git a/model/process_text.go b/model/process_text.go index 7083f36fd..0d75a0ed0 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -19,7 +19,7 @@ const ( ) type TextProcessor interface { - Encode(string) ([]int32, error) + Encode(s string, addSpecial bool) ([]int32, error) Decode([]int32) (string, error) Is(int32, Special) bool } @@ -144,7 +144,7 @@ type merge struct { runes []rune } -func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { +func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { fragments := 
[]fragment{{value: s}} for _, special := range bpe.vocab.SpecialVocabulary() { // TODO: process special tokens concurrently @@ -177,7 +177,6 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { for _, frag := range fragments { if len(frag.ids) > 0 { ids = append(ids, frag.ids...) - slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true) continue } @@ -201,7 +200,6 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { // short circuit if the fragment is in the vocabulary if id := bpe.vocab.Encode(sb.String()); id >= 0 { ids = append(ids, id) - slog.Debug("encoded", "text", sb.String(), "ids", []int32{id}) continue } @@ -275,14 +273,13 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { // TODO: handle the edge case where the rune isn't in the vocabulary if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 { ids = append(ids, id) - slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id}) } } } } } - if len(ids) > 0 { + if addSpecial && len(ids) > 0 { if bpe.vocab.AddBOS { if ids[0] == bpe.vocab.BOS { slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS) @@ -329,6 +326,5 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { } } - slog.Debug("decoded", "ids", ids, "text", sb.String()) return sb.String(), nil } diff --git a/model/process_text_test.go b/model/process_text_test.go index cad1f94ff..f48303212 100644 --- a/model/process_text_test.go +++ b/model/process_text_test.go @@ -74,7 +74,7 @@ func TestLlama(t *testing.T) { t.Run("simple", func(t *testing.T) { t.Parallel() - ids, err := tokenizer.Encode("hello world") + ids, err := tokenizer.Encode("hello world", true) if err != nil { t.Error(err) } @@ -92,7 +92,7 @@ func TestLlama(t *testing.T) { t.Errorf("got %q, want hello world", s) } - ids, err = tokenizer.Encode("hello <|end_of_text|>") + ids, err = tokenizer.Encode("hello <|end_of_text|>", true) if err != nil { t.Error(err) } @@ -126,7 +126,7 @@ func TestLlama(t *testing.T) { } for s, want := range cases { - ids, err := tokenizer.Encode(s) + ids, err := tokenizer.Encode(s, true) if err != nil { t.Error(err) } @@ -152,7 +152,7 @@ func TestLlama(t *testing.T) { } for _, want := range cases { - ids, err := tokenizer.Encode(want) + ids, err := tokenizer.Encode(want, true) if err != nil { t.Error(err) } @@ -176,7 +176,7 @@ func TestLlama(t *testing.T) { } for s, want := range cases { - ids, err := tokenizer.Encode(s) + ids, err := tokenizer.Encode(s, true) if err != nil { t.Fatal(err) } @@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { b.Run("encode"+strconv.Itoa(n), func(b *testing.B) { b.ResetTimer() for range b.N { - _, err := tokenizer.Encode(string(bts)) + _, err := tokenizer.Encode(string(bts), true) if err != nil { b.Fatal(err) } @@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { }) b.Run("decode"+strconv.Itoa(n), func(b *testing.B) { - ids, err := tokenizer.Encode(string(bts)) + ids, err := tokenizer.Encode(string(bts), true) if err != nil { b.Fatal(err) } diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 82880c980..8662afc1e 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -931,7 +931,6 @@ func Execute(args []string) error { slog.Info("starting go runner") llama.BackendInit() - slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads) server := &Server{ batchSize: *batchSize, diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 
e1fa98b1a..a411fddb1 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -5,12 +5,12 @@ import ( "fmt" "log/slog" "math" - "reflect" "time" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) type InputCache struct { @@ -39,10 +39,7 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots slots := make([]InputCacheSlot, numSlots) for i := range slots { - slots[i] = InputCacheSlot{ - Id: i, - Inputs: make([]input, 0), - } + slots[i] = InputCacheSlot{Id: i} } cache := model.Config().Cache @@ -62,9 +59,9 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots func kvCacheTypeFromStr(s string) ml.DType { switch s { case "q8_0": - panic("kv cache quantization not yet implemented") + return ml.DTypeQ80 case "q4_0": - panic("kv cache quantization not yet implemented") + return ml.DTypeQ40 default: return ml.DTypeF16 } @@ -83,7 +80,7 @@ type InputCacheSlot struct { Id int // Inputs that are stored in the KV cache - Inputs []input + Inputs []input.Input // is this cache actively being processed as part of a sequence? InUse bool @@ -92,7 +89,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) { +func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) { var slot *InputCacheSlot var numPast int32 var err error @@ -143,7 +140,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach return slot, prompt, nil } -func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { +func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { longest := int32(-1) var longestSlot *InputCacheSlot @@ -166,7 +163,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int3 return longestSlot, longest, nil } -func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { +func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { oldest := time.Now() var oldestSlot *InputCacheSlot @@ -202,7 +199,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, if longest > 0 && longestSlot != oldestSlot { slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", len(longestSlot.Inputs)) - oldestSlot.Inputs = make([]input, longest) + oldestSlot.Inputs = make([]input.Input, longest) copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) if c.cache != nil { c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) @@ -212,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, return oldestSlot, longest, nil } -func countCommonPrefix(a []input, b []input) int32 { +func countCommonPrefix(a []input.Input, b []input.Input) int32 { var count int32 for i := range a { @@ -220,7 +217,7 @@ func countCommonPrefix(a []input, b []input) int32 { break } - if !reflect.DeepEqual(a[i], b[i]) { + if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash { break } diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 99e67b4fa..0a1b73f5a 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -4,6 +4,8 @@ import ( "image" "testing" "time" + + 
"github.com/ollama/ollama/model/input" ) func TestCountCommon(t *testing.T) { @@ -13,44 +15,50 @@ func TestCountCommon(t *testing.T) { tests := []struct { name string - t1 []input - t2 []input + t1 []input.Input + t2 []input.Input expected int32 }{ { name: "Equal", - t1: []input{{token: 1}, {token: 2}, {token: 3}}, - t2: []input{{token: 1}, {token: 2}, {token: 3}}, + t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 3, }, { name: "Prefix", - t1: []input{{token: 1}}, - t2: []input{{token: 1}, {token: 2}, {token: 3}}, + t1: []input.Input{{Token: 1}}, + t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 1, }, { name: "Image Prefix", - t1: []input{{image: imgA}}, - t2: []input{{image: imgA}, {image: imgB}, {image: imgC}}, + t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}}, + t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}}, expected: 1, }, { name: "Mixed", - t1: []input{{token: 1}, {image: imgA}}, - t2: []input{{token: 1}, {image: imgA}, {token: 5}}, + t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, + t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}}, expected: 2, }, + { + name: "Mixed, Same Length", + t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, + t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}}, + expected: 1, + }, { name: "Empty", - t1: []input{}, - t2: []input{{token: 1}, {token: 2}, {token: 3}}, + t1: []input.Input{}, + t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 0, }, { name: "Both Empty", - t1: []input{}, - t2: []input{}, + t1: []input.Input{}, + t2: []input.Input{}, expected: 0, }, } @@ -74,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) { tests := []struct { name string cache InputCache - prompt []input + prompt []input.Input longest expected best expected }{ @@ -83,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, { Id: 1, - Inputs: []input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, }}, - prompt: []input{{token: 1}}, + prompt: []input.Input{{Token: 1}}, longest: expected{result: 0, len: 0}, best: expected{result: 0, len: 0}, }, @@ -103,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input{{token: 1}}, + Inputs: []input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input{{token: 1}, {token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []input{{token: 1}, {token: 2}}, + prompt: []input.Input{{Token: 1}, {Token: 2}}, longest: expected{result: 1, len: 2}, best: expected{result: 1, len: 2}, }, @@ -123,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input{{token: 1}, {token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, }}, - prompt: []input{{token: 2}}, + prompt: []input.Input{{Token: 2}}, longest: expected{result: 0, len: 0}, best: expected{result: 1, len: 0}, }, @@ -144,19 +152,19 @@ func 
TestFindCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input{{token: 1}, {token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, }, }, - prompt: []input{{token: 1}}, + prompt: []input.Input{{Token: 1}}, longest: expected{result: 0, len: 1}, best: expected{result: 1, len: 1}, }, @@ -165,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input{{token: 1}}, + Inputs: []input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input{{token: 1}, {token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []input{{token: 2}, {token: 3}}, + prompt: []input.Input{{Token: 2}, {Token: 3}}, longest: expected{result: 0, len: 0}, best: expected{result: 1, len: 0}, }, @@ -185,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input{{token: 1}, {token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: true, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input{{token: 1}}, + Inputs: []input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []input{{token: 1}, {token: 2}}, + prompt: []input.Input{{Token: 1}, {Token: 2}}, longest: expected{result: 1, len: 1}, best: expected{result: 1, len: 2}, }, diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index db9b271e5..c1475cbb2 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -1,13 +1,12 @@ package ollamarunner import ( - "bytes" "context" "encoding/json" "errors" "flag" "fmt" - "image" + "hash/maphash" "log" "log/slog" "net" @@ -27,28 +26,26 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" "github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/sample" _ "github.com/ollama/ollama/model/models" ) -// input is an element of the prompt to process, either a token or an image -type input struct { - token int32 - - image image.Image -} - type Sequence struct { + // ctx for allocating tensors that last the lifetime of the sequence, such as + // multimodal embeddings + ctx ml.Context + // batch index iBatch int // prompt inputs left to evaluate - inputs []input + inputs []input.Input // inputs that have been added to a batch but not yet submitted to Forward - pendingInputs []input + pendingInputs []input.Input // tokens that have been generated but not returned yet (e.g. 
for stop sequences) pendingResponses []string @@ -101,8 +98,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen s.ready.Wait() startTime := time.Now() + ctx := s.model.Backend().NewContext() - inputs, err := s.inputs(prompt, images) + inputs, err := s.inputs(ctx, prompt, images) if err != nil { return nil, fmt.Errorf("failed to process inputs: %w", err) } else if len(inputs) == 0 { @@ -128,6 +126,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // TODO(jessegross): Ingest cached history for grammar return &Sequence{ + ctx: ctx, inputs: inputs, numPromptInputs: len(inputs), startProcessingTime: startTime, @@ -146,28 +145,31 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { - var inputs []input +func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) { + var inputs []input.Input var parts []string var matches [][]string - // TODO(jessegross): This can sometimes trigger for matching text in the - // user's prompt. We previously tried to avoid it by only looking for images - // on image models. We don't have a clear indication now but it would be better - // to properly escape it in any case. - re := regexp.MustCompile(`\[img-(\d+)\]`) - parts = re.Split(prompt, -1) - matches = re.FindAllStringSubmatch(prompt, -1) + multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor) + if visionModel { + re := regexp.MustCompile(`\[img-(\d+)\]`) + parts = re.Split(prompt, -1) + matches = re.FindAllStringSubmatch(prompt, -1) + } else { + parts = []string{prompt} + } + + postTokenize := false for i, part := range parts { // text - tokenize - tokens, err := s.model.(model.TextProcessor).Encode(part) + tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) if err != nil { return nil, err } for _, t := range tokens { - inputs = append(inputs, input{token: t}) + inputs = append(inputs, input.Input{Token: t}) } // image - decode and store @@ -186,12 +188,25 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { return nil, fmt.Errorf("invalid image index: %d", n) } - image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data)) + imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data) if err != nil { return nil, err } - inputs = append(inputs, input{image: image}) + s.multimodalHash.Reset() + _, _ = s.multimodalHash.Write(images[imageIndex].Data) + imageHash := s.multimodalHash.Sum64() + + inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) + postTokenize = true + } + } + + if visionModel && postTokenize { + var err error + inputs, err = multimodalProcessor.PostTokenize(ctx, inputs) + if err != nil { + return nil, err } } @@ -236,8 +251,15 @@ type Server struct { // KV cache cache *InputCache - // next sequence for prompt processing to avoid starvation - nextSeq int + // multimodalHash generates hashes for comparing equality + // of non-text data + multimodalHash maphash.Hash + + // vocab is a llama.cpp vocab required for gammar-based + // constrained generation (json mode, structured outputs) + // TODO: this is temporary until Ollama sampling supports + // constrained generation + vocab *sample.Vocab } func (s 
*Server) allNil() bool { @@ -283,6 +305,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) { close(seq.responses) close(seq.embedding) seq.cache.InUse = false + seq.ctx.Close() s.seqs[seqIndex] = nil s.seqsSem.Release(1) } @@ -310,30 +333,25 @@ func (s *Server) processBatch() error { } defer s.mu.Unlock() - var options model.Options - imgSeq := -1 - - seqIdx := s.nextSeq - 1 - for range s.seqs { - seqIdx = (seqIdx + 1) % len(s.seqs) - seq := s.seqs[seqIdx] + var options input.Options + for i, seq := range s.seqs { if seq == nil { continue } // if past the num predict limit if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { - s.removeSequence(seqIdx, "limit") + s.removeSequence(i, "limit") continue } if !s.cache.enabled { seq.inputs = append(seq.cache.Inputs, seq.inputs...) - seq.cache.Inputs = []input{} + seq.cache.Inputs = []input.Input{} } - for i, input := range seq.inputs { + for j, inp := range seq.inputs { if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { if len(seq.pendingInputs) == 0 { err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) @@ -345,37 +363,23 @@ func (s *Server) processBatch() error { } } - if i >= s.batchSize { + if j >= s.batchSize { break } - // TODO(jessegross): Image inputs need to be rethought - it's - // it doesn't work well for different types of models or multiple sequences - if input.image != nil { - if len(seq.pendingInputs) != len(options.Images) { - break - } - - if imgSeq != seqIdx && imgSeq != -1 { - s.nextSeq = seqIdx - break - } - - imgSeq = seqIdx - options.Images = append(options.Images, input.image) - seq.pendingInputs = append(seq.pendingInputs, input) - continue + options.Inputs = append(options.Inputs, inp.Token) + if inp.Multimodal != nil { + options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal}) } - options.Inputs = append(options.Inputs, input.token) options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) options.Sequences = append(options.Sequences, seq.cache.Id) seq.iBatch = len(options.Outputs) - if i+1 == len(seq.inputs) { + if j+1 == len(seq.inputs) { options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1)) } - seq.pendingInputs = append(seq.pendingInputs, input) + seq.pendingInputs = append(seq.pendingInputs, inp) } seq.inputs = seq.inputs[len(seq.pendingInputs):] @@ -403,7 +407,7 @@ func (s *Server) processBatch() error { // After calling Forward, pending inputs are now in the cache if len(seq.pendingInputs) > 0 { seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) 
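To make the batch bookkeeping above concrete, here is an illustrative (made-up) example of the input.Options that processBatch would assemble for a single sequence in cache slot 2 that already has 4 inputs cached and contributes three new inputs, the second of which carries an image embedding. The token IDs are invented and imgEmbedding stands for whatever EncodeMultimodal returned earlier:

opts := input.Options{
	Inputs:     []int32{101, -1, 102}, // made-up token ids; -1 marks the image placeholder slot
	Multimodal: []input.MultimodalIndex{{Index: 1, Multimodal: imgEmbedding}},
	Positions:  []int32{4, 5, 6}, // continues after the 4 cached inputs
	Sequences:  []int{2, 2, 2},   // all entries come from cache slot 2
	Outputs:    []int32{2},       // only the last input of the sequence produces logits
}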
- seq.pendingInputs = []input{} + seq.pendingInputs = []input.Input{} } // don't sample prompt processing @@ -422,6 +426,7 @@ func (s *Server) processBatch() error { // if done processing the prompt, generate an embedding and return if seq.embeddingOnly { // TODO(jessegross): Embedding support + slog.Warn("generation of embedding outputs not yet supported") s.removeSequence(i, "") continue } @@ -449,7 +454,7 @@ func (s *Server) processBatch() error { return err } - seq.inputs = []input{{token: token}} + seq.inputs = []input.Input{{Token: token}} seq.pendingResponses = append(seq.pendingResponses, piece) sequence := strings.Join(seq.pendingResponses, "") @@ -575,11 +580,30 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + var grammar *sample.Grammar + var err error + if req.Grammar != "" { + grammar, err = sample.NewGrammar(s.vocab, req.Grammar) + if err != nil { + http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError) + return + } + } + + sampler := sample.NewSampler( + req.Temperature, + req.TopK, + req.TopP, + req.MinP, + req.Seed, + grammar, + ) + seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ numPredict: req.NumPredict, stop: req.Stop, numKeep: int32(req.NumKeep), - sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized + sampler: sampler, embedding: false, }) if err != nil { @@ -786,7 +810,7 @@ func (s *Server) loadModel( panic(err) } - slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads) + s.vocab = sample.NewVocab(mpath) // TODO(jessegross): LoRA loading if lpath.String() != "" { @@ -818,7 +842,7 @@ func Execute(args []string) error { batchSize := fs.Int("batch-size", 512, "Batch size") numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") mainGPU := fs.Int("main-gpu", 0, "Main GPU") - _ = fs.Bool("flash-attn", false, "Enable flash attention") + flashAttention := fs.Bool("flash-attn", false, "Enable flash attention") kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") port := fs.Int("port", 8080, "Port to expose the server on") @@ -863,7 +887,6 @@ func Execute(args []string) error { } // TODO(jessegross): Parameters that need to be implemented: - // flash-attn // no-mmap // mlock @@ -878,10 +901,11 @@ func Execute(args []string) error { } params := ml.BackendParams{ - NumThreads: *threads, - NumGPULayers: *numGPULayers, - MainGPU: *mainGPU, - TensorSplit: tensorSplitFloats, + NumThreads: *threads, + NumGPULayers: *numGPULayers, + MainGPU: *mainGPU, + TensorSplit: tensorSplitFloats, + FlashAttention: *flashAttention, } server.ready.Add(1) diff --git a/sample/samplers.go b/sample/samplers.go index 1b8a5edd9..aea99b3f2 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -3,118 +3,226 @@ package sample import ( "errors" "math" + "math/rand/v2" + "slices" + "sync" - "golang.org/x/exp/rand" - "gonum.org/v1/gonum/stat/sampleuv" + "github.com/ollama/ollama/llama" ) -type Sampler interface { - Sample([]float32) (int32, error) +// token represents information about a single token during sampling +type token struct { + id int32 // The token's unique identifier + value float32 // The raw logit or probability from the model } -type weighted struct { - src rand.Source - transforms []Transform +type Sampler struct { + rng *rand.Rand + topK int + topP float32 + minP float32 + 
temperature float32 + grammar *Grammar } -// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279 -func Weighted(seed *uint64, transforms ...Transform) Sampler { - var src rand.Source - if seed != nil { - src = rand.NewSource(*seed) - } - return weighted{src: src, transforms: transforms} -} - -func (s weighted) Sample(logits []float32) (int32, error) { - logits64 := make([]float64, len(logits)) - for i, v := range logits { - logits64[i] = float64(v) - } - - for _, t := range s.transforms { - logits64 = t.Apply(logits64) - } - - logitsCopy := make([]float64, 0, len(logits)) - indices := make([]int, 0, len(logits)) - for i, logit := range logits64 { - if !math.IsInf(logit, -1) { - logitsCopy = append(logitsCopy, logit) - indices = append(indices, i) - } - } - - if len(logitsCopy) == 0 { - return -1, errors.New("no valid logits found for weighed sampling") - } - - probs := softmax(logitsCopy) - w := sampleuv.NewWeighted(probs, s.src) - if idx, ok := w.Take(); ok { - return int32(indices[idx]), nil - } - return -1, errors.New("weighted sampler failed, no valid token found") -} - -type greedy struct{} - -func Greedy() Sampler { - return greedy{} -} - -// Sample returns the index of the maximum value in logits. -func (s greedy) Sample(logits []float32) (int32, error) { - if len(logits) == 0 { - return -1, errors.New("no logits provided for greedy sampling") - } - - maxIdx := 0 +func (s *Sampler) Sample(logits []float32) (int32, error) { + tokens := make([]token, len(logits)) for i := range logits { - if logits[i] > logits[maxIdx] { - maxIdx = i + tokens[i].id = int32(i) + tokens[i].value = logits[i] + } + + t, err := s.sample(tokens) + if err != nil { + return -1, err + } + + if s.grammar != nil { + // optimization: first check if the max logit is accepted by the grammar + // if the max logit is rejected, apply the grammar to all logits (slower) + top := []token{t} + s.grammar.Apply(top) + if !math.IsInf(float64(top[0].value), -1) { + s.grammar.Accept(top[0].id) + return top[0].id, nil + } + + // since .sample has side effects of modifying the tokens + // we need to reset them before applying the grammar and + // sampling again + for i := range logits { + tokens[i].id = int32(i) + tokens[i].value = logits[i] + } + s.grammar.Apply(tokens) + t, err = s.sample(tokens) + if err != nil { + return -1, err + } + s.grammar.Accept(t.id) + } + + return t.id, nil +} + +// greedy returns the highest probability token from the tokens +func greedy(tokens []token) token { + max := tokens[0] + for i := 1; i < len(tokens); i++ { + if tokens[i].value > max.value { + max = tokens[i] } } - return int32(maxIdx), nil + return max +} + +// sample returns the highest probability token from the tokens +// given sampler parameters. 
It also has side effects of modifying the tokens +func (s *Sampler) sample(tokens []token) (token, error) { + if s.temperature == 0 { + return greedy(tokens), nil + } + + if s.topK > 0 { + tokens = topK(tokens, s.topK) + } else { + sortLogits(tokens) + } + + // token logit values are updated to probabilities + tokens = temperature(tokens, s.temperature) + + tokens = topP(tokens, s.topP) + tokens = minP(tokens, s.minP) + + // TODO: this should fall back to greedy sampling + // or topP, topK values etc should be such that + // there are always tokens to sample from + if len(tokens) == 0 { + return token{}, errors.New("no tokens to sample from") + } + + var r float32 + if s.rng != nil { + r = s.rng.Float32() + } else { + r = rand.Float32() + } + + // Calculate cumulative sum of probabilities + var sum float32 + for i := range tokens { + sum += tokens[i].value + tokens[i].value = sum + } + r *= tokens[len(tokens)-1].value + + idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int { + if token.value < target { + return -1 + } + return 1 + }) + + return tokens[idx], nil } // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 -func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) { - if temperature == 0 { - return Greedy(), nil +func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler { + var rng *rand.Rand + if seed != -1 { + // PCG requires two parameters: sequence and stream + // Use original seed for sequence + sequence := uint64(seed) + // Use golden ratio hash to generate statistically independent seeds + rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9)) + } + if temperature < 0.0 { + temperature = 0.0 } - if temperature < 0 || temperature > 2 { - return nil, errors.New("temperature must be between 0 and 2") + if topP < 0.0 { + topP = 0.0 + } + if topP >= 1.0 { + topP = 1.0 } - transforms := []Transform{Temperature(temperature)} - - if topK != 0 { - if topK <= 0 { - return nil, errors.New("topK must be greater than 0") - } - transforms = append(transforms, TopK(topK)) + if minP < 0.0 { + minP = 0.0 + } + if minP >= 1.0 { + minP = 1.0 } - if topP != 0 { - if topP < 0 || topP >= 1 { - return nil, errors.New("topP must be between 0 and 1") - } - transforms = append(transforms, TopP(topP)) + return Sampler{ + rng: rng, + topK: topK, + topP: topP, + minP: minP, + temperature: temperature, + grammar: grammar, } - - if minP != 0 { - if minP < 0 || minP >= 1 { - return nil, errors.New("minP must be between 0 and 1") - } - transforms = append(transforms, MinP(minP)) - } - - if seed >= 0 { - seed64 := uint64(seed) - return Weighted(&seed64, transforms...), nil - } - return Weighted(nil, transforms...), nil +} + +type Grammar struct { + vocab *Vocab + grammar string + sampler *llama.Sampler +} + +func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) { + v, err := vocab.Load() + if err != nil { + return nil, err + } + + return &Grammar{ + vocab: vocab, + grammar: grammar, + sampler: llama.NewGrammarSampler(v, grammar), + }, nil +} + +func (g *Grammar) Apply(tokens []token) { + tds := make([]llama.TokenData, len(tokens)) + for i, token := range tokens { + tds[i].Id = token.id + tds[i].Logit = token.value + } + + g.sampler.Apply(tds) + + for i := range tokens { + tokens[i].value = tds[i].Logit + } +} + +func (g *Grammar) Accept(token int32) { + g.sampler.Accept(token) +} + +type Vocab struct { + 
once sync.Once + vocab *llama.Vocab + err error + path string +} + +func NewVocab(path string) *Vocab { + return &Vocab{path: path} +} + +// Load returns the lazily-loaded vocabulary +func (v *Vocab) Load() (*llama.Vocab, error) { + v.once.Do(func() { + vocab, err := llama.LoadVocabFromFile(v.path) + if err != nil { + v.err = err + return + } + v.vocab = vocab + }) + return v.vocab, v.err } diff --git a/sample/samplers_benchmark_test.go b/sample/samplers_benchmark_test.go new file mode 100644 index 000000000..cd1380141 --- /dev/null +++ b/sample/samplers_benchmark_test.go @@ -0,0 +1,92 @@ +package sample + +import ( + "fmt" + "math/rand" + "testing" +) + +func BenchmarkWeightedSampler(b *testing.B) { + sizes := []int{10, 100, 1000, 10000} + + for _, size := range sizes { + b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) { + logits := make([]float32, size) + for i := range logits { + logits[i] = float32(rand.Float64()*10 - 5) + } + + sampler := NewSampler(0.8, 0, 0, 0, 42, nil) + b.ResetTimer() + for b.Loop() { + sampler.Sample(logits) + } + }) + } + + configs := []struct { + name string + temperature float32 + topK int + topP float32 + minP float32 + seed int + }{ + {"Greedy", 0, -1, 0, 0, -1}, + {"Temperature", 0.8, -1, 0, 0, -1}, + {"TopK", 0.8, 50, 0, 0, -1}, + {"TopP", 0.8, -1, 0.9, 0, -1}, + {"MinP", 0.8, -1, 0, 0.05, -1}, + {"WithSeed", 0.8, 50, 0, 0, 42}, + } + + // Fixed size for common vocab size + size := 128000 + logits := make([]float32, size) + for i := range logits { + logits[i] = float32(rand.Float64()*10 - 5) + } + + for _, tc := range configs { + b.Run("Config"+tc.name, func(b *testing.B) { + sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil) + sampler.Sample(logits) + + b.ResetTimer() + + for b.Loop() { + sampler.Sample(logits) + } + }) + } + + // Test with combined transforms separately - topK influences performance greatly + b.Run("TransformCombined", func(b *testing.B) { + sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil) + b.ResetTimer() + + for b.Loop() { + sampler.Sample(logits) + } + }) +} + +func BenchmarkGreedySampler(b *testing.B) { + sizes := []int{10, 100, 1000, 10000, 100000} + + for _, size := range sizes { + b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) { + logits := make([]float32, size) + for i := range logits { + logits[i] = float32(rand.Float64()*10 - 5) + } + + sampler := NewSampler(0, -1, 0, 0, -1, nil) + b.ResetTimer() + + for b.Loop() { + sampler.Sample(logits) + } + }) + } +} diff --git a/sample/samplers_test.go b/sample/samplers_test.go index 32364a3b7..38b9b352a 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -1,15 +1,14 @@ package sample import ( - "math" "math/rand/v2" "testing" - - "github.com/google/go-cmp/cmp" ) func TestWeighted(t *testing.T) { - got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))}) + logits := []float32{-10, 3, -10, -10} + sampler := NewSampler(0, 0, 0, 0, 0, nil) + got, err := sampler.Sample(logits) if err != nil { t.Error(err) return @@ -19,194 +18,26 @@ func TestWeighted(t *testing.T) { t.Errorf("index mismatch: want %d, got %d", want, got) } - got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))}) - if err == nil { - t.Error("expected error for no valid tokens, got index", got) - } - - seed := uint64(42) - got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4}) + logits = []float32{-100, -10, 0, 10} + sampler = NewSampler(0, 0, 0, 0, 
0, nil) + got, err = sampler.Sample(logits) if err != nil { t.Error(err) return } - // With seed 42, we expect a consistent sample - want = int32(3) // This will be deterministic due to the seed + want = int32(3) // Should pick highest probability with this r value if want != got { t.Errorf("index mismatch: want %d, got %d", want, got) } } -type testTransform struct { - id int - callOrder *[]int -} - -func (ts *testTransform) Apply(logits []float64) []float64 { - if ts.callOrder != nil { - *ts.callOrder = append(*ts.callOrder, ts.id) - } - return logits -} - -func TestSample(t *testing.T) { - input := []float32{1, 2, 3, 4} - - var callOrder []int - mock1 := &testTransform{ - id: 1, - callOrder: &callOrder, - } - mock2 := &testTransform{ - id: 2, - callOrder: &callOrder, - } - mock3 := &testTransform{ - id: 3, - callOrder: &callOrder, - } - - _, err := Weighted(nil, mock1, mock2, mock3).Sample(input) - if err != nil { - t.Error(err) - return - } - wantOrder := []int{1, 2, 3} - if diff := cmp.Diff(wantOrder, callOrder); diff != "" { - t.Errorf("call order mismatch (-want +got):\n%s", diff) - } -} - -func TestNewSampler(t *testing.T) { - tests := []struct { - name string - temperature float32 - topK int - topP float32 - minP float32 - seed int - wantErr bool - }{ - { - name: "no transforms", - // temperature is 0, so greedy should be used - wantErr: false, - }, - { - name: "temperature", - temperature: 0.5, - wantErr: false, - }, - { - name: "invalid temperature negative", - temperature: -1, - wantErr: true, - }, - { - name: "invalid temperature too high", - temperature: 2.1, - wantErr: true, - }, - { - name: "top k", - topK: 10, - temperature: 0.8, - wantErr: false, - }, - { - name: "invalid top k negative", - topK: -1, - temperature: 0.8, - wantErr: true, - }, - { - name: "top p", - topP: 0.9, - temperature: 0.8, - wantErr: false, - }, - { - name: "invalid top p negative", - topP: -0.1, - temperature: 0.8, - wantErr: true, - }, - { - name: "invalid top p one", - topP: 1.0, - temperature: 0.8, - wantErr: true, - }, - { - name: "min p", - minP: 0.2, - temperature: 0.8, - wantErr: false, - }, - { - name: "invalid min p negative", - minP: -0.1, - temperature: 0.8, - wantErr: true, - }, - { - name: "invalid min p one", - minP: 1.0, - temperature: 0.8, - wantErr: true, - }, - { - name: "default values", - temperature: 0.8, - topK: 40, - topP: 0.9, - minP: 0.0, - seed: 0, - wantErr: false, - }, - { - name: "all zeroes", - temperature: 0.0, - topK: 0, - topP: 0.0, - minP: 0.0, - seed: 0, - wantErr: false, // all zeroes means no transforms - }, - { - name: "all transforms", - temperature: 0.8, - topK: 50, - topP: 0.95, - minP: 0.1, - seed: 42, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed) - if (err != nil) != tt.wantErr { - t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - func BenchmarkSample(b *testing.B) { - transforms := []Transform{ - Temperature(0.5), - TopK(10), - TopP(0.9), - MinP(0.2), - } - samplers := map[string]Sampler{ - "Greedy": Greedy(), - "Weighted": Weighted(nil, transforms...), + "Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy + "Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil), } + // Generate random logits for benchmarking logits := make([]float32, 1<<16) for i := range logits { logits[i] = rand.Float32() @@ -215,9 +46,9 @@ func BenchmarkSample(b *testing.B) { for name, s := range 
samplers { b.Run(name, func(b *testing.B) { b.ResetTimer() - for range b.N { + for b.Loop() { if _, err := s.Sample(logits); err != nil { - b.Error(err) + b.Fatalf("error sampling: %v", err) } } }) diff --git a/sample/transforms.go b/sample/transforms.go index 2dc6ebae1..ab62455f3 100644 --- a/sample/transforms.go +++ b/sample/transforms.go @@ -1,120 +1,195 @@ package sample import ( - "cmp" "math" "slices" - - pq "github.com/emirpasic/gods/v2/queues/priorityqueue" ) -type Transform interface { - Apply([]float64) []float64 -} - -// TODO(parthsareen): potentially cache softmax values -func softmax(logits []float64) []float64 { - var sum float64 - probs := make([]float64, len(logits)) - for i, v := range logits { - probs[i] = math.Exp(v) - sum += probs[i] - } - - for i := range probs { - probs[i] /= sum - } - - return probs -} - -type Temperature float64 - -func (t Temperature) Apply(logits []float64) []float64 { - temp := math.Max(float64(t), 1e-7) - - // subtracting max logit to avoid under/overflow - maxLogit := slices.Max(logits) - for i := range logits { - logits[i] = (logits[i] - maxLogit) / temp - } - - return logits -} - -type logitMap struct { - index int - logit float64 -} - -type TopK int - -// TODO(parthsareen): avoid having to check all logits after this transform -func (k TopK) Apply(logits []float64) []float64 { - if int(k) >= len(logits) { - return logits - } - q := pq.NewWith(func(a, b logitMap) int { - return -cmp.Compare(a.logit, b.logit) - }) - - for i, logit := range logits { - q.Enqueue(logitMap{index: i, logit: logit}) - } - - validLogits := make(map[int]float64) - for range k { - logitMap, _ := q.Dequeue() - validLogits[logitMap.index] = logitMap.logit - } - - for i := range logits { - if _, ok := validLogits[i]; !ok { - logits[i] = math.Inf(-1) +// temperature applies scaling and softmax to the logits +func temperature(ts []token, temp float32) []token { + // Find max logit for numerical stability + maxLogit := float32(math.Inf(-1)) + for _, t := range ts { + if t.value > maxLogit { + maxLogit = t.value } } - return logits -} - -type TopP float64 - -func (p TopP) Apply(logits []float64) []float64 { - probs := softmax(logits) - indices := make([]int, len(probs)) - for i := range indices { - indices[i] = i + // Apply temperature and compute exp(x - max) + temp = max(temp, 1e-7) + var sum float32 + for i, v := range ts { + ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp))) + sum += ts[i].value } - // sort in descending order - slices.SortFunc(indices, func(i, j int) int { - return cmp.Compare(probs[j], probs[i]) - }) + // Normalize + for i := range ts { + ts[i].value /= sum + } - var sum float64 - for i, idx := range indices { - sum += probs[idx] - if sum > float64(p) { - for _, idx := range indices[i+1:] { - logits[idx] = math.Inf(-1) - } + return ts +} + +// siftDown maintains a min-heap property by recursively moving larger elements down the heap. +// +// The heap is represented as an array where for any node at index i: +// - Left child is at index 2i + 1 +// - Right child is at index 2i + 2 +// - Parent is at index (i-1)/2 +// +// The function compares a node with its children and: +// 1. Finds the smallest value between the node and its children +// 2. If the node is not the smallest, swaps it with its smallest child +// 3. 
Continues this process down the affected path until the min-heap property is restored +func siftDown(data []token, start, end int) { + root := start + for { + child := 2*root + 1 + if child >= end { break } + // Find smaller child (we want min heap) + if child+1 < end && data[child+1].value < data[child].value { + child++ + } + // Exit if root is already smaller than children + if data[root].value <= data[child].value { + break + } + // Swap with smaller child and continue + data[root], data[child] = data[child], data[root] + root = child } - return logits } -type MinP float64 +// topK limits the number of tokens considered to the k highest logits +func topK(ts []token, k int) []token { + if k >= len(ts) { + return ts + } + // Heapify + siftDown - O(nlog(k)) + // Build min-heap of first k elements + heap := ts[:k] + for i := k/2 - 1; i >= 0; i-- { + siftDown(heap, i, k) + } -func (p MinP) Apply(logits []float64) []float64 { - probs := softmax(logits) - threshold := slices.Max(probs) * float64(p) - - for i, prob := range probs { - if prob < threshold { - logits[i] = math.Inf(-1) + // Process remaining elements - if larger than heap root, replace root + for i := k; i < len(ts); i++ { + if ts[i].value > heap[0].value { + heap[0] = ts[i] + siftDown(heap, 0, k) } } - return logits + slices.Reverse(heap) + + ts = heap + return ts +} + +// topP limits tokens to those with cumulative probability p +func topP(ts []token, p float32) []token { + if p == 1.0 { + return ts + } + + // Find cutoff index where cumulative sum exceeds p + var sum float32 + for i, t := range ts { + sum += t.value + if sum > float32(p) { + ts = ts[:i+1] + return ts + } + } + + return ts +} + +// minP limits tokens to those with cumulative probability p +func minP(ts []token, p float32) []token { + if p == 1.0 { + return ts + } + + maxProb := float32(math.Inf(-1)) + for _, token := range ts { + if token.value > maxProb { + maxProb = token.value + } + } + + threshold := maxProb * float32(p) + + // Filter tokens in-place + validTokens := ts[:0] + for i, token := range ts { + if token.value >= threshold { + validTokens = append(validTokens, ts[i]) + } + } + + ts = validTokens + return ts +} + +// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584 +// sortLogits sorts implementation to sort tokens by logits using counting sort +// counting sort is faster than built-in sort for this use case +func sortLogits(tokens []token) { + if len(tokens) <= 1 { + return + } + + // Find max/min in a single pass + minLogit, maxLogit := tokens[0].value, tokens[0].value + for _, t := range tokens[1:] { + if t.value < minLogit { + minLogit = t.value + } else if t.value > maxLogit { + maxLogit = t.value + } + } + + // Calculate scaling to map to uint32 range + logitRange := maxLogit - minLogit + if logitRange < 1e-6 { + return // All values effectively equal + } + + // Count frequencies directly from tokens + const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity + var counts [256]int // For first byte + + // First pass: count frequencies + for _, t := range tokens { + // Map to [0, maxInt] range + score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt) + counts[score>>16]++ + } + + // Calculate offsets + var offset int + for i := range counts { + count := counts[i] + counts[i] = offset + offset += count + } + + // Second pass: place elements in correct position + output := make([]token, len(tokens)) + // Track current positions + countsCopy := counts + + for i, 
t := range tokens { + score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt) + + pos := countsCopy[score>>16] + countsCopy[score>>16]++ + output[len(tokens)-1-pos] = tokens[i] + } + + copy(tokens, output) } diff --git a/sample/transforms_test.go b/sample/transforms_test.go index 05f76a274..81e8849b7 100644 --- a/sample/transforms_test.go +++ b/sample/transforms_test.go @@ -4,77 +4,175 @@ import ( "math" "math/rand/v2" "testing" - - "github.com/google/go-cmp/cmp" ) -func TestTemperature(t *testing.T) { - got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0}) - want := []float64{-4, -10, 0, -14, -6, -12, -8} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("logits mismatch (-want +got):\n%s", diff) +// Helper to convert float64 slice to logit slice +func toTokens(values []float64) []token { + tokens := make([]token, len(values)) + for i, v := range values { + tokens[i] = token{ + id: int32(i), + value: float32(v), + } + } + return tokens +} + +// Helper to compare logit slices +func compareLogits(t *testing.T, name string, want []float64, got []token) { + t.Helper() + if len(want) != len(got) { + t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got)) + return + } + for i := range want { + if math.Abs(float64(got[i].value)-want[i]) > 1e-6 { + t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value) + } } } -func TestSoftmax(t *testing.T) { - got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4}) +func TestTemperatureAndSoftmax(t *testing.T) { + input := []float64{1, 4, -2, 0} + got := temperature(toTokens(input), 0.5) - want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("probs mismatch (-want +got):\n%s", diff) + // Check probabilities sum to 1 + var sum float32 + for _, token := range got { + sum += token.value + } + if math.Abs(float64(sum)-1.0) > 1e-6 { + t.Errorf("probabilities don't sum to 1: got %f", sum) + } + + got = temperature(toTokens(input), 1) + // Check probabilities sum to 1 + sum = 0.0 + for _, token := range got { + sum += token.value + } + if math.Abs(float64(sum)-1.0) > 1e-6 { + t.Errorf("probabilities don't sum to 1: got %f", sum) } } func TestTopK(t *testing.T) { - got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) - want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("logits mismatch (-want +got):\n%s", diff) - } + input := []float64{-3, -2, -1, 0, 1, 2, 4} - got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) - - want = []float64{-3, -2, -1, 0, 1, 2, 4} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("logits mismatch (-want +got):\n%s", diff) + // Test k=3 + got := topK(toTokens(input), 3) + if len(got) != 3 { + t.Errorf("topK(3): wrong length: want 3, got %d", len(got)) } + // Should keep highest 3 values: 4, 2, 1 + want := []float64{4, 2, 1} + compareLogits(t, "topK(3)", want, got) + + // Test k > len + got = topK(toTokens(input), 10) + compareLogits(t, "topK(10)", input, got) } func TestTopP(t *testing.T) { - got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) - want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("logits mismatch (-want +got):\n%s", diff) + input := []float64{-3, -2, -1, 0, 1, 2, 4} + tokens := 
toTokens(input) + + // First apply temperature and softmax to get probabilities + tokens = temperature(tokens, 1) + sortLogits(tokens) + + // Then apply topP + got := topP(tokens, 0.95) + + // Should keep tokens until cumsum > 0.95 + if len(got) > 3 { + t.Errorf("topP(0.95): kept too many tokens: got %d", len(got)) + t.Logf("got: %v", got) } } func TestMinP(t *testing.T) { - got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3}) - want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("logits mismatch (-want +got):\n%s", diff) + input := []float64{-3, -2, -1, 0, 1, 2, 4, 3} + tokens := toTokens(input) + + // First apply temperature and softmax + tokens = temperature(tokens, 1) + + // Then apply minP + got := minP(tokens, 0.2) + + // Should keep tokens with prob >= 0.2 * max_prob + if len(got) > 3 { + t.Errorf("minP(0.2): kept too many tokens: got %d", len(got)) } } -func BenchmarkTransform(b *testing.B) { - transforms := map[string]Transform{ - "Temperature": Temperature(0.5), - "TopK": TopK(10), - "TopP": TopP(0.9), - "MinP": MinP(0.2), +func TestSortLogits(t *testing.T) { + input := []float64{3, 1, 4, 2, -1, 0, -2} + tokens := toTokens(input) + + sortLogits(tokens) + + for i := 1; i < len(tokens); i++ { + if tokens[i].value > tokens[i-1].value { + t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f", + i, tokens[i].value, tokens[i-1].value) + } } - logits := make([]float64, 1<<16) - for i := range logits { - logits[i] = rand.Float64() - } - - for name, transform := range transforms { - b.Run(name, func(b *testing.B) { - b.ResetTimer() - for range b.N { - transform.Apply(logits) - } - }) - } + want := []float64{4, 3, 2, 1, 0, -1, -2} + compareLogits(t, "sortLogits", want, tokens) +} + +func BenchmarkTransforms(b *testing.B) { + // Generate random logits + tokens := make([]token, 1<<16) + for i := range tokens { + tokens[i] = token{ + id: int32(i), + value: rand.Float32(), + } + } + + tokensCopy := make([]token, len(tokens)) + + b.Run("Temperature", func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + copy(tokensCopy, tokens) + temperature(tokensCopy, 0.5) + } + }) + + b.Run("TopK", func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + copy(tokensCopy, tokens) + topK(tokensCopy, 10) + } + }) + + b.Run("TopP", func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + copy(tokensCopy, tokens) + topP(tokensCopy, 0.9) + } + }) + + b.Run("MinP", func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + copy(tokensCopy, tokens) + minP(tokensCopy, 0.2) + } + }) + + b.Run("SortTokens", func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + copy(tokensCopy, tokens) + sortLogits(tokensCopy) + } + }) } diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 62930d7f2..60485df85 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -80,13 +80,14 @@ function checkEnv() { function buildOllama() { + mkdir -Force -path "${script:DIST_DIR}\" if ($script:ARCH -ne "arm64") { Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}" New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0 & cmake --fresh --preset CPU --install-prefix $script:DIST_DIR if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - & cmake --build --preset CPU --parallel $script:JOBS + & cmake --build --preset CPU --config Release --parallel $script:JOBS if ($LASTEXITCODE -ne 0) { 
exit($LASTEXITCODE)} & cmake --install build --component CPU --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} @@ -101,7 +102,7 @@ function buildOllama() { # to avoid 2022 (or newer) from being used as the default & cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - & cmake --build --preset "CUDA 11" --parallel $script:JOBS + & cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build --component "CUDA" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} @@ -112,7 +113,7 @@ function buildOllama() { write-host "Building CUDA v12 backend libraries" & cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - & cmake --build --preset "CUDA 12" --parallel $script:JOBS + & cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build --component "CUDA" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} @@ -131,7 +132,7 @@ function buildOllama() { $env:HIPCXX="" $env:HIP_PLATFORM="" $env:CMAKE_PREFIX_PATH="" - & cmake --build --preset "ROCm" --parallel $script:JOBS + & cmake --build --preset "ROCm" --config Release --parallel $script:JOBS if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build --component "HIP" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} diff --git a/scripts/install.sh b/scripts/install.sh index 9e146e508..9c232400f 100644 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -77,11 +77,12 @@ if [ -d "$OLLAMA_INSTALL_DIR/lib/ollama" ] ; then fi status "Installing ollama to $OLLAMA_INSTALL_DIR" $SUDO install -o0 -g0 -m755 -d $BINDIR -$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR" +$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama" status "Downloading Linux ${ARCH} bundle" curl --fail --show-error --location --progress-bar \ "https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \ $SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR" + if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then status "Making ollama accessible in the PATH in $BINDIR" $SUDO ln -sf "$OLLAMA_INSTALL_DIR/ollama" "$BINDIR/ollama" diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index e4c36d7d8..423a6ad23 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -24,8 +24,10 @@ import ( "os" "path/filepath" "runtime" + "slices" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -43,9 +45,9 @@ import ( // Errors var ( - // ErrManifestNotFound is returned when a manifest is not found in the + // ErrModelNotFound is returned when a manifest is not found in the // cache or registry. - ErrManifestNotFound = errors.New("manifest not found") + ErrModelNotFound = errors.New("model not found") // ErrManifestInvalid is returned when a manifest found in a local or // remote cache is invalid. @@ -53,7 +55,7 @@ var ( // ErrMissingModel is returned when the model part of a name is missing // or invalid. - ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}") + ErrNameInvalid = errors.New("invalid or missing name") // ErrCached is passed to [Trace.PushUpdate] when a layer already // exists. It is a non-fatal error and is never returned by [Registry.Push]. 
@@ -72,19 +74,22 @@ const ( DefaultMaxChunkSize = 8 << 20 ) -// DefaultCache returns a new disk cache for storing models. If the -// OLLAMA_MODELS environment variable is set, it uses that directory; -// otherwise, it uses $HOME/.ollama/models. -func DefaultCache() (*blob.DiskCache, error) { +var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) { dir := os.Getenv("OLLAMA_MODELS") if dir == "" { - home, err := os.UserHomeDir() - if err != nil { - return nil, err - } + home, _ := os.UserHomeDir() + home = cmp.Or(home, ".") dir = filepath.Join(home, ".ollama", "models") } return blob.Open(dir) +}) + +// DefaultCache returns the default cache used by the registry. It is +// configured from the OLLAMA_MODELS environment variable, or defaults to +// $HOME/.ollama/models, or, if an error occurs obtaining the home directory, +// it uses the current working directory. +func DefaultCache() (*blob.DiskCache, error) { + return defaultCache() } // Error is the standard error returned by Ollama APIs. It can represent a @@ -109,7 +114,18 @@ type Error struct { } func (e *Error) Error() string { - return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message) + var b strings.Builder + b.WriteString("registry responded with status ") + b.WriteString(strconv.Itoa(e.Status)) + if e.Code != "" { + b.WriteString(": code ") + b.WriteString(e.Code) + } + if e.Message != "" { + b.WriteString(": ") + b.WriteString(e.Message) + } + return b.String() } func (e *Error) LogValue() slog.Value { @@ -167,6 +183,10 @@ func CompleteName(name string) string { // Registry is a client for performing push and pull operations against an // Ollama registry. type Registry struct { + // Cache is the cache used to store models. If nil, [DefaultCache] is + // used. + Cache *blob.DiskCache + // UserAgent is the User-Agent header to send with requests to the // registry. If empty, the User-Agent is determined by HTTPClient. UserAgent string @@ -205,10 +225,28 @@ type Registry struct { // It is only used when a layer is larger than [MaxChunkingThreshold]. MaxChunkSize int64 - // NameMask, if set, is the name used to convert non-fully qualified - // names to fully qualified names. If empty, the default mask - // ("registry.ollama.ai/library/_:latest") is used. - NameMask string + // Mask, if set, is the name used to convert non-fully qualified names + // to fully qualified names. If empty, [DefaultMask] is used. + Mask string +} + +func (r *Registry) cache() (*blob.DiskCache, error) { + if r.Cache != nil { + return r.Cache, nil + } + return defaultCache() +} + +func (r *Registry) parseName(name string) (names.Name, error) { + mask := defaultMask + if r.Mask != "" { + mask = names.Parse(r.Mask) + } + n := names.Merge(names.Parse(name), mask) + if !n.IsFullyQualified() { + return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name) + } + return n, nil } // DefaultRegistry returns a new Registry configured from the environment. The @@ -243,52 +281,6 @@ func DefaultRegistry() (*Registry, error) { return &rc, nil } -type PushParams struct { - // From is an optional destination name for the model. If empty, the - // destination name is the same as the source name. - From string -} - -// parseName parses name using [names.ParseExtended] and then merges the name with the -// default name, and checks that the name is fully qualified. If a digest is -// present, it parse and returns it with the other fields as their zero values. 
-// -// It returns an error if the name is not fully qualified, or if the digest, if -// any, is invalid. -// -// The scheme is returned as provided by [names.ParseExtended]. -func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) { - maskName := defaultMask - if mask != "" { - maskName = names.Parse(mask) - if !maskName.IsFullyQualified() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask) - } - } - scheme, n, ds := names.ParseExtended(s) - if !n.IsValid() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) - } - n = names.Merge(n, maskName) - if ds != "" { - // Digest is present. Validate it. - d, err = blob.ParseDigest(ds) - if err != nil { - return "", names.Name{}, blob.Digest{}, err - } - } - - // The name check is deferred until after the digest check because we - // say that digests take precedence over names, and so should there - // errors when being parsed. - if !n.IsFullyQualified() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) - } - - scheme = cmp.Or(scheme, "https") - return scheme, n, d, nil -} - func (r *Registry) maxStreams() int { n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) @@ -308,13 +300,24 @@ func (r *Registry) maxChunkSize() int64 { return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize) } +type PushParams struct { + // From is an optional destination name for the model. If empty, the + // destination name is the same as the source name. + From string +} + // Push pushes the model with the name in the cache to the remote registry. -func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error { +func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { if p == nil { p = &PushParams{} } - m, err := r.ResolveLocal(c, cmp.Or(p.From, name)) + c, err := r.cache() + if err != nil { + return err + } + + m, err := r.ResolveLocal(cmp.Or(p.From, name)) if err != nil { return err } @@ -337,7 +340,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p * t := traceFromContext(ctx) - scheme, n, _, err := parseName(name, r.NameMask) + scheme, n, _, err := r.parseNameExtended(name) if err != nil { // This should never happen since ResolveLocal should have // already validated the name. @@ -363,7 +366,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p * n.Model(), l.Digest, ) - res, err := r.doOK(ctx, "POST", startURL, nil) + res, err := r.send(ctx, "POST", startURL, nil) if err != nil { return err } @@ -387,7 +390,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p * } req.ContentLength = l.Size - res, err = doOK(r.client(), req) + res, err = sendRequest(r.client(), req) if err == nil { res.Body.Close() } @@ -407,7 +410,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p * n.Model(), n.Tag(), ) - res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data)) + res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data)) if err == nil { res.Body.Close() } @@ -430,8 +433,8 @@ func canRetry(err error) bool { // chunks of the specified size, and then reassembled and verified. This is // typically slower than splitting the model up across layers, and is mostly // utilized for layers of type equal to "application/vnd.ollama.image". 
-func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error { - scheme, n, _, err := parseName(name, r.NameMask) +func (r *Registry) Pull(ctx context.Context, name string) error { + scheme, n, _, err := r.parseNameExtended(name) if err != nil { return err } @@ -444,6 +447,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err return fmt.Errorf("%w: no layers", ErrManifestInvalid) } + c, err := r.cache() + if err != nil { + return err + } + exists := func(l *Layer) bool { info, err := c.Get(l.Digest) return err == nil && info.Size == l.Size @@ -451,10 +459,15 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err t := traceFromContext(ctx) - var g errgroup.Group + g, ctx := errgroup.WithContext(ctx) g.SetLimit(r.maxStreams()) - for _, l := range m.Layers { + layers := m.Layers + if m.Config != nil && m.Config.Digest.IsValid() { + layers = append(layers, m.Config) + } + + for _, l := range layers { if exists(l) { t.update(l, l.Size, ErrCached) continue @@ -471,7 +484,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err if l.Size <= r.maxChunkingThreshold() { g.Go(func() error { - res, err := doOK(r.client(), req) + // TODO(bmizerany): retry/backoff like below in + // the chunking case + res, err := sendRequest(r.client(), req) if err != nil { return err } @@ -497,19 +512,21 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err // fire an initial request to get the final URL and // then use that URL for the chunk requests. req.Header.Set("Range", "bytes=0-0") - res, err := doOK(r.client(), req) + res, err := sendRequest(r.client(), req) if err != nil { return err } res.Body.Close() req = res.Request.WithContext(req.Context()) - streamNo := 0 - tws := make([]*bufio.Writer, r.maxStreams()-1) + wp := writerPool{size: r.maxChunkSize()} + for chunk := range chunks.Of(l.Size, r.maxChunkSize()) { + if ctx.Err() != nil { + break + } + ticket := q.Take() - bufIdx := streamNo % len(tws) - streamNo++ g.Go(func() (err error) { defer func() { if err != nil { @@ -523,23 +540,18 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err if err != nil { return err } - err := func() error { req := req.Clone(req.Context()) req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk)) - res, err := doOK(r.client(), req) + res, err := sendRequest(r.client(), req) if err != nil { return err } defer res.Body.Close() - tw := tws[bufIdx] - if tw == nil { - tw = bufio.NewWriterSize(nil, int(r.maxChunkSize())) - tws[bufIdx] = tw - } + tw := wp.get() tw.Reset(ticket) - defer tw.Reset(nil) // release ticket + defer wp.put(tw) _, err = io.CopyN(tw, res.Body, chunk.Size()) if err != nil { @@ -581,8 +593,12 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified // before attempting to unlink the model. 
-func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) { - _, n, _, err := parseName(name, r.NameMask) +func (r *Registry) Unlink(name string) (ok bool, _ error) { + n, err := r.parseName(name) + if err != nil { + return false, err + } + c, err := r.cache() if err != nil { return false, err } @@ -594,6 +610,9 @@ type Manifest struct { Name string `json:"-"` // the canonical name of the model Data []byte `json:"-"` // the raw data of the manifest Layers []*Layer `json:"layers"` + + // For legacy reasons, we still have to download the config layer. + Config *Layer `json:"config"` } var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000") @@ -657,14 +676,18 @@ type Layer struct { Size int64 `json:"size"` } -// ResolveLocal resolves a name to a Manifest in the local cache. The name is -// parsed using [names.ParseExtended] but the scheme is ignored. -func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) { - _, n, d, err := parseName(name, r.NameMask) +// ResolveLocal resolves a name to a Manifest in the local cache. +func (r *Registry) ResolveLocal(name string) (*Manifest, error) { + _, n, d, err := r.parseNameExtended(name) + if err != nil { + return nil, err + } + c, err := r.cache() if err != nil { return nil, err } if !d.IsValid() { + // No digest, so resolve the manifest by name. d, err = c.Resolve(n.String()) if err != nil { return nil, err @@ -673,7 +696,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro data, err := os.ReadFile(c.GetFile(d)) if err != nil { if errors.Is(err, fs.ErrNotExist) { - return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name) + return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name) } return nil, err } @@ -686,7 +709,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro // Resolve resolves a name to a Manifest in the remote registry. func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) { - scheme, n, d, err := parseName(name, r.NameMask) + scheme, n, d, err := r.parseNameExtended(name) if err != nil { return nil, err } @@ -696,7 +719,7 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d) } - res, err := r.doOK(ctx, "GET", manifestURL, nil) + res, err := r.send(ctx, "GET", manifestURL, nil) if err != nil { return nil, err } @@ -721,7 +744,7 @@ func (r *Registry) client() *http.Client { } // newRequest constructs a new request, ready to use, with the given method, -// url, and body, presigned with client Key and UserAgent. +// url, and body, pre-signed with client [Key] and [UserAgent]. func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { @@ -740,11 +763,17 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R return req, nil } -// doOK makes a request with the given client and request, and returns the +// sendRequest makes a request with the given client and request, and returns the // response if the status code is 200. If the status code is not 200, an Error // is parsed from the response body and returned. If any other error occurs, it // is returned. 
-func doOK(c *http.Client, r *http.Request) (*http.Response, error) { +func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("request error %s: %w", r.URL, err) + } + }() + if r.URL.Scheme == "https+insecure" { // TODO(bmizerany): clone client.Transport, set // InsecureSkipVerify, etc. @@ -787,20 +816,26 @@ func doOK(c *http.Client, r *http.Request) (*http.Response, error) { // Use the raw body if we can't parse it as an error object. re.Message = string(out) } + + // coerce MANIFEST_UNKNOWN to ErrManifestNotFound + if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") { + return nil, ErrModelNotFound + } + re.Status = res.StatusCode return nil, &re } return res, nil } -// doOK is a convenience method for making a request with newRequest and -// passing it to doOK with r.client(). -func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { +// send is a convenience method for making a request with newRequest and +// passing it to send with r.client(). +func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { req, err := r.newRequest(ctx, method, path, body) if err != nil { return nil, err } - return doOK(r.client(), req) + return sendRequest(r.client(), req) } // makeAuthToken creates an Ollama auth token for the given private key. @@ -869,3 +904,114 @@ func maybeUnexpectedEOF(err error) error { } return err } + +type publicError struct { + wrapped error + message string +} + +func withPublicMessagef(err error, message string, args ...any) error { + return publicError{wrapped: err, message: fmt.Sprintf(message, args...)} +} + +func (e publicError) Error() string { return e.message } +func (e publicError) Unwrap() error { return e.wrapped } + +var supportedSchemes = []string{ + "http", + "https", + "https+insecure", +} + +var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", ")) + +// parseNameExtended parses and validates an extended name, returning the scheme, name, +// and digest. +// +// If the scheme is empty, scheme will be "https". If an unsupported scheme is +// given, [ErrNameInvalid] wrapped with a display friendly message is returned. +// +// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly +// message is returned. +// +// If the name is not, once merged with the mask, fully qualified, +// [ErrNameInvalid] wrapped with a display friendly message is returned. +func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) { + scheme, name, digest := splitExtended(s) + scheme = cmp.Or(scheme, "https") + if !slices.Contains(supportedSchemes, scheme) { + err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage) + return "", names.Name{}, blob.Digest{}, err + } + + var d blob.Digest + if digest != "" { + var err error + d, err = blob.ParseDigest(digest) + if err != nil { + err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest) + return "", names.Name{}, blob.Digest{}, err + } + if name == "" { + // We have can resolve a manifest from a digest only, + // so skip name validation and return the scheme and + // digest. 
+ return scheme, names.Name{}, d, nil + } + } + + n, err := r.parseName(name) + if err != nil { + return "", names.Name{}, blob.Digest{}, err + } + return scheme, n, d, nil +} + +// splitExtended splits an extended name string into its scheme, name, and digest +// parts. +// +// Examples: +// +// http://ollama.com/bmizerany/smol:latest@digest +// https://ollama.com/bmizerany/smol:latest +// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme. +// model@digest +// @digest +func splitExtended(s string) (scheme, name, digest string) { + i := strings.Index(s, "://") + if i >= 0 { + scheme = s[:i] + s = s[i+3:] + } + i = strings.LastIndex(s, "@") + if i >= 0 { + digest = s[i+1:] + s = s[:i] + } + return scheme, s, digest +} + +type writerPool struct { + size int64 // set by the caller + + mu sync.Mutex + ws []*bufio.Writer +} + +func (p *writerPool) get() *bufio.Writer { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.ws) == 0 { + return bufio.NewWriterSize(nil, int(p.size)) + } + w := p.ws[len(p.ws)-1] + p.ws = p.ws[:len(p.ws)-1] + return w +} + +func (p *writerPool) put(w *bufio.Writer) { + p.mu.Lock() + defer p.mu.Unlock() + w.Reset(nil) + p.ws = append(p.ws, w) +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index af898c268..8f4e1604f 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -2,6 +2,7 @@ package ollama import ( "bytes" + "cmp" "context" "encoding/json" "errors" @@ -72,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error // To simulate a network error, pass a handler that returns a 499 status code. func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { t.Helper() + c, err := blob.Open(t.TempDir()) if err != nil { t.Fatal(err) @@ -84,14 +86,15 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { } } - rc := &Registry{ + r := &Registry{ + Cache: c, HTTPClient: &http.Client{ Transport: recordRoundTripper(h), }, } link := func(name string, manifest string) { - _, n, _, err := parseName(name, rc.NameMask) + n, err := r.parseName(name) if err != nil { panic(err) } @@ -122,7 +125,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499}) link("invalid", "!!!!!") - return rc, c + return r, c } func okHandler(w http.ResponseWriter, r *http.Request) { @@ -145,84 +148,61 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest { return d } -func TestRegistryPushInvalidNames(t *testing.T) { - rc, c := newClient(t, nil) - - cases := []struct { - name string - err error - }{ - {"", ErrNameInvalid}, - {"@", ErrNameInvalid}, - {"@x", blob.ErrInvalidDigest}, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - // Create a new registry and push a new image. 
- err := rc.Push(t.Context(), c, tt.name, nil) - if !errors.Is(err, tt.err) { - t.Errorf("err = %v; want %v", err, tt.err) - } - }) - } -} - func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) { t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }} return WithTrace(ctx, t), t } func TestPushZero(t *testing.T) { - rc, c := newClient(t, okHandler) - err := rc.Push(t.Context(), c, "empty", nil) + rc, _ := newClient(t, okHandler) + err := rc.Push(t.Context(), "empty", nil) if !errors.Is(err, ErrManifestInvalid) { t.Errorf("err = %v; want %v", err, ErrManifestInvalid) } } func TestPushSingle(t *testing.T) { - rc, c := newClient(t, okHandler) - err := rc.Push(t.Context(), c, "single", nil) + rc, _ := newClient(t, okHandler) + err := rc.Push(t.Context(), "single", nil) testutil.Check(t, err) } func TestPushMultiple(t *testing.T) { - rc, c := newClient(t, okHandler) - err := rc.Push(t.Context(), c, "multiple", nil) + rc, _ := newClient(t, okHandler) + err := rc.Push(t.Context(), "multiple", nil) testutil.Check(t, err) } func TestPushNotFound(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { t.Errorf("unexpected request: %v", r) }) - err := rc.Push(t.Context(), c, "notfound", nil) + err := rc.Push(t.Context(), "notfound", nil) if !errors.Is(err, fs.ErrNotExist) { t.Errorf("err = %v; want %v", err, fs.ErrNotExist) } } func TestPushNullLayer(t *testing.T) { - rc, c := newClient(t, nil) - err := rc.Push(t.Context(), c, "null", nil) + rc, _ := newClient(t, nil) + err := rc.Push(t.Context(), "null", nil) if err == nil || !strings.Contains(err.Error(), "invalid manifest") { t.Errorf("err = %v; want invalid manifest", err) } } func TestPushSizeMismatch(t *testing.T) { - rc, c := newClient(t, nil) + rc, _ := newClient(t, nil) ctx, _ := withTraceUnexpected(t.Context()) - got := rc.Push(ctx, c, "sizemismatch", nil) + got := rc.Push(ctx, "sizemismatch", nil) if got == nil || !strings.Contains(got.Error(), "size mismatch") { t.Errorf("err = %v; want size mismatch", got) } } func TestPushInvalid(t *testing.T) { - rc, c := newClient(t, nil) - err := rc.Push(t.Context(), c, "invalid", nil) + rc, _ := newClient(t, nil) + err := rc.Push(t.Context(), "invalid", nil) if err == nil || !strings.Contains(err.Error(), "invalid manifest") { t.Errorf("err = %v; want invalid manifest", err) } @@ -230,7 +210,7 @@ func TestPushInvalid(t *testing.T) { func TestPushExistsAtRemote(t *testing.T) { var pushed bool - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/uploads/") { if !pushed { // First push. Return an uploadURL. 
@@ -258,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) { check := testutil.Checker(t) - err := rc.Push(ctx, c, "single", nil) + err := rc.Push(ctx, "single", nil) check(err) if !errors.Is(errors.Join(errs...), nil) { t.Errorf("errs = %v; want %v", errs, []error{ErrCached}) } - err = rc.Push(ctx, c, "single", nil) + err = rc.Push(ctx, "single", nil) check(err) } func TestPushRemoteError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/blobs/") { w.WriteHeader(500) io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`) return } }) - got := rc.Push(t.Context(), c, "single", nil) + got := rc.Push(t.Context(), "single", nil) checkErrCode(t, got, 500, "blob_error") } func TestPushLocationError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Location", ":///x") w.WriteHeader(http.StatusAccepted) }) - got := rc.Push(t.Context(), c, "single", nil) + got := rc.Push(t.Context(), "single", nil) wantContains := "invalid upload URL" if got == nil || !strings.Contains(got.Error(), wantContains) { t.Errorf("err = %v; want to contain %v", got, wantContains) @@ -294,14 +274,14 @@ func TestPushLocationError(t *testing.T) { } func TestPushUploadRoundtripError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if r.Host == "blob.store" { w.WriteHeader(499) // force RoundTrip error on upload return } w.Header().Set("Location", "http://blob.store/blobs/123") }) - got := rc.Push(t.Context(), c, "single", nil) + got := rc.Push(t.Context(), "single", nil) if !errors.Is(got, errRoundTrip) { t.Errorf("got = %v; want %v", got, errRoundTrip) } @@ -317,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) { os.Remove(c.GetFile(l.Digest)) }, }) - got := rc.Push(ctx, c, "single", nil) + got := rc.Push(ctx, "single", nil) if !errors.Is(got, fs.ErrNotExist) { t.Errorf("got = %v; want fs.ErrNotExist", got) } } func TestPushCommitRoundtripError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/blobs/") { panic("unexpected") } w.WriteHeader(499) // force RoundTrip error }) - err := rc.Push(t.Context(), c, "zero", nil) + err := rc.Push(t.Context(), "zero", nil) if !errors.Is(err, errRoundTrip) { t.Errorf("err = %v; want %v", err, errRoundTrip) } @@ -344,8 +324,8 @@ func checkNotExist(t *testing.T, err error) { } func TestRegistryPullInvalidName(t *testing.T) { - rc, c := newClient(t, nil) - err := rc.Pull(t.Context(), c, "://") + rc, _ := newClient(t, nil) + err := rc.Pull(t.Context(), "://") if !errors.Is(err, ErrNameInvalid) { t.Errorf("err = %v; want %v", err, ErrNameInvalid) } @@ -360,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) { } for _, resp := range cases { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, resp) }) - err := rc.Pull(t.Context(), c, "x") + err := rc.Pull(t.Context(), "x") if !errors.Is(err, ErrManifestInvalid) { t.Errorf("err = %v; want invalid manifest", err) } @@ -386,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) { }) // Confirm that the 
layer does not exist locally - _, err := rc.ResolveLocal(c, "model") + _, err := rc.ResolveLocal("model") checkNotExist(t, err) _, err = c.Get(d) checkNotExist(t, err) - err = rc.Pull(t.Context(), c, "model") + err = rc.Pull(t.Context(), "model") check(err) mw, err := rc.Resolve(t.Context(), "model") check(err) - mg, err := rc.ResolveLocal(c, "model") + mg, err := rc.ResolveLocal("model") check(err) if !reflect.DeepEqual(mw, mg) { t.Errorf("mw = %v; mg = %v", mw, mg) @@ -422,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) { func TestRegistryPullCached(t *testing.T) { cached := blob.DigestFromBytes("exists") - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/blobs/") { w.WriteHeader(499) // should not be called return @@ -445,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() - err := rc.Pull(ctx, c, "single") + err := rc.Pull(ctx, "single") testutil.Check(t, err) want := []int64{6} @@ -458,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) { } func TestRegistryPullManifestNotFound(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) }) - err := rc.Pull(t.Context(), c, "notfound") + err := rc.Pull(t.Context(), "notfound") checkErrCode(t, err, 404, "") } func TestRegistryPullResolveRemoteError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) io.WriteString(w, `{"errors":[{"code":"an_error"}]}`) }) - err := rc.Pull(t.Context(), c, "single") + err := rc.Pull(t.Context(), "single") checkErrCode(t, err, 500, "an_error") } func TestRegistryPullResolveRoundtripError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/manifests/") { w.WriteHeader(499) // force RoundTrip error return } }) - err := rc.Pull(t.Context(), c, "single") + err := rc.Pull(t.Context(), "single") if !errors.Is(err, errRoundTrip) { t.Errorf("err = %v; want %v", err, errRoundTrip) } @@ -534,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) { // Check that we pull all layers that we can. - err := rc.Pull(ctx, c, "mixed") + err := rc.Pull(ctx, "mixed") if err != nil { t.Fatal(err) } @@ -552,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) { } func TestRegistryPullChunking(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range")) if r.URL.Host != "blob.store" { // The production registry redirects to the blob store. 
@@ -590,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) { }, }) - err := rc.Pull(ctx, c, "remote") + err := rc.Pull(ctx, "remote") testutil.Check(t, err) want := []int64{0, 3, 6} @@ -622,13 +602,13 @@ func TestInsecureSkipVerify(t *testing.T) { })) defer s.Close() - const name = "ollama.com/library/insecure" + const name = "library/insecure" var rc Registry url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name) _, err := rc.Resolve(t.Context(), url) if err == nil || !strings.Contains(err.Error(), "failed to verify") { - t.Errorf("err = %v; want cert verifiction failure", err) + t.Errorf("err = %v; want cert verification failure", err) } url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name) @@ -724,3 +704,115 @@ func TestErrorUnmarshal(t *testing.T) { }) } } + +// TestParseNameErrors tests that parseName returns errors messages with enough +// detail for users to debug naming issues they may encounter. Previous to this +// test, the error messages were not very helpful and each problem was reported +// as the same message. +// +// It is only for testing error messages, not that all invalids and valids are +// covered. Those are in other tests for names.Name and blob.Digest. +func TestParseNameExtendedErrors(t *testing.T) { + cases := []struct { + name string + err error + want string + }{} + + var r Registry + for _, tt := range cases { + _, _, _, err := r.parseNameExtended(tt.name) + if !errors.Is(err, tt.err) { + t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err) + } + if err != nil && !strings.Contains(err.Error(), tt.want) { + t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want) + } + } +} + +func TestParseNameExtended(t *testing.T) { + cases := []struct { + in string + scheme string + name string + digest string + err string + }{ + {in: "http://m", scheme: "http", name: "m"}, + {in: "https+insecure://m", scheme: "https+insecure", name: "m"}, + {in: "http+insecure://m", err: "unsupported scheme"}, + + {in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"}, + + {in: "", err: "invalid or missing name"}, + {in: "m", scheme: "https", name: "m"}, + {in: "://", err: "invalid or missing name"}, + {in: "@sha256:deadbeef", err: "invalid digest"}, + {in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"}, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + var r Registry + scheme, n, digest, err := r.parseNameExtended(tt.in) + if err != nil { + if tt.err == "" { + t.Errorf("err = %v; want nil", err) + } else if !strings.Contains(err.Error(), tt.err) { + t.Errorf("err = %v; want %q", err, tt.err) + } + } else if tt.err != "" { + t.Errorf("err = nil; want %q", tt.err) + } + if err == nil && !n.IsFullyQualified() { + t.Errorf("name = %q; want fully qualified", n) + } + + if scheme != tt.scheme { + t.Errorf("scheme = %q; want %q", scheme, tt.scheme) + } + + // smoke-test name is superset of tt.name + if !strings.Contains(n.String(), tt.name) { + t.Errorf("name = %q; want %q", n, tt.name) + } + + tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String()) + if digest.String() != tt.digest { + t.Errorf("digest = %q; want %q", digest, tt.digest) + } + }) + } +} + +func TestUnlink(t *testing.T) { + t.Run("found by name", func(t *testing.T) { + rc, _ := newClient(t, nil) + + // confirm linked + _, err := rc.ResolveLocal("single") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + 
+ // unlink + _, err = rc.Unlink("single") + testutil.Check(t, err) + + // confirm unlinked + _, err = rc.ResolveLocal("single") + if !errors.Is(err, fs.ErrNotExist) { + t.Errorf("err = %v; want fs.ErrNotExist", err) + } + }) + t.Run("not found by name", func(t *testing.T) { + rc, _ := newClient(t, nil) + ok, err := rc.Unlink("manifestNotFound") + if err != nil { + t.Fatal(err) + } + if ok { + t.Error("expected not found") + } + }) +} diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go index 8e53040ad..69435c406 100644 --- a/server/internal/client/ollama/trace.go +++ b/server/internal/client/ollama/trace.go @@ -6,13 +6,20 @@ import ( // Trace is a set of functions that are called to report progress during blob // downloads and uploads. +// +// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push] +// and [Registry.Pull]. type Trace struct { // Update is called during [Registry.Push] and [Registry.Pull] to // report the progress of blob uploads and downloads. // - // It is called once at the beginning of the download with a zero n and - // then once per read operation with the number of bytes read so far, - // and an error if any. + // The n argument is the number of bytes transferred so far, and err is + // any error that has occurred. If n == 0, and err is nil, the download + // or upload has just started. If err is [ErrCached], the download or + // upload has been skipped because the blob is already present in the + // local cache or remote registry, respectively. Otherwise, if err is + // non-nil, the download or upload has failed. When l.Size == n, and + // err is nil, the download or upload has completed. // // A function assigned must be safe for concurrent use. The function is // called synchronously and so should not block or take long to run. 
diff --git a/server/internal/cmd/opp/opp.go b/server/internal/cmd/opp/opp.go index c21e71d59..6976927c7 100644 --- a/server/internal/cmd/opp/opp.go +++ b/server/internal/cmd/opp/opp.go @@ -63,25 +63,28 @@ func main() { } flag.Parse() - c, err := ollama.DefaultCache() - if err != nil { - log.Fatal(err) - } - - rc, err := ollama.DefaultRegistry() - if err != nil { - log.Fatal(err) - } - ctx := context.Background() - err = func() error { + err := func() error { switch cmd := flag.Arg(0); cmd { case "pull": - return cmdPull(ctx, rc, c) + rc, err := ollama.DefaultRegistry() + if err != nil { + log.Fatal(err) + } + + return cmdPull(ctx, rc) case "push": - return cmdPush(ctx, rc, c) + rc, err := ollama.DefaultRegistry() + if err != nil { + log.Fatal(err) + } + return cmdPush(ctx, rc) case "import": + c, err := ollama.DefaultCache() + if err != nil { + log.Fatal(err) + } return cmdImport(ctx, c) default: if cmd == "" { @@ -99,7 +102,7 @@ func main() { } } -func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { +func cmdPull(ctx context.Context, rc *ollama.Registry) error { model := flag.Arg(1) if model == "" { flag.Usage() @@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error errc := make(chan error) go func() { - errc <- rc.Pull(ctx, c, model) + errc <- rc.Pull(ctx, model) }() t := time.NewTicker(time.Second) @@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error } } -func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { +func cmdPush(ctx context.Context, rc *ollama.Registry) error { args := flag.Args()[1:] flag := flag.NewFlagSet("push", flag.ExitOnError) flagFrom := flag.String("from", "", "Use the manifest from a model by another name.") @@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error } from := cmp.Or(*flagFrom, model) - m, err := rc.ResolveLocal(c, from) + m, err := rc.ResolveLocal(from) if err != nil { return err } @@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error }, }) - return rc.Push(ctx, c, model, &ollama.PushParams{ + return rc.Push(ctx, model, &ollama.PushParams{ From: from, }) } diff --git a/server/internal/internal/backoff/backoff_test.go b/server/internal/internal/backoff/backoff_test.go index bb8438a78..11ace22a8 100644 --- a/server/internal/internal/backoff/backoff_test.go +++ b/server/internal/internal/backoff/backoff_test.go @@ -1,3 +1,5 @@ +//go:build goexperiment.synctest + package backoff import ( diff --git a/server/internal/internal/names/name.go b/server/internal/internal/names/name.go index 361cce76f..f0a1185dc 100644 --- a/server/internal/internal/names/name.go +++ b/server/internal/internal/names/name.go @@ -8,7 +8,7 @@ import ( "github.com/ollama/ollama/server/internal/internal/stringsx" ) -const MaxNameLength = 50 + 1 + 50 + 1 + 50 // /: +const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // //: type Name struct { // Make incomparable to enfoce use of Compare / Equal for @@ -25,19 +25,12 @@ type Name struct { // format of a valid name string is: // // s: -// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest } // { host } "/" { namespace } "/" { model } ":" { tag } -// { host } "/" { namespace } "/" { model } "@" { digest } // { host } "/" { namespace } "/" { model } -// { namespace } "/" { model } ":" { tag } "@" { digest } // { namespace } "/" { model } ":" { tag } -// { namespace } "/" { model } "@" { 
digest } // { namespace } "/" { model } -// { model } ":" { tag } "@" { digest } // { model } ":" { tag } -// { model } "@" { digest } // { model } -// "@" { digest } // host: // pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }* // length: [1, 350] @@ -50,9 +43,6 @@ type Name struct { // tag: // pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }* // length: [1, 80] -// digest: -// pattern: { alphanum | "_" } { alphanum | "-" | ":" }* -// length: [1, 80] // // The name returned is not guaranteed to be valid. If it is not valid, the // field values are left in an undefined state. Use [Name.IsValid] to check @@ -82,23 +72,17 @@ func Parse(s string) Name { } } -// ParseExtended parses and returns any scheme, Name, and digest from from s in -// the the form [scheme://][name][@digest]. All parts are optional. -// -// If the scheme is present, it must be followed by "://". The digest is -// prefixed by "@" and comes after the name. The name is parsed using [Parse]. -// -// The scheme and digest are stripped before the name is parsed by [Parse]. -// -// For convience, the scheme is never empty. If the scheme is not present, the -// returned scheme is "https". +// Split splits an extended name string into its scheme, name, and digest +// parts. // // Examples: // // http://ollama.com/bmizerany/smol:latest@digest // https://ollama.com/bmizerany/smol:latest // ollama.com/bmizerany/smol:latest@digest // returns "https" scheme. -func ParseExtended(s string) (scheme string, _ Name, digest string) { +// model@digest +// @digest +func Split(s string) (scheme, name, digest string) { i := strings.Index(s, "://") if i >= 0 { scheme = s[:i] @@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) { digest = s[i+1:] s = s[:i] } - return scheme, Parse(s), digest -} - -func FormatExtended(scheme string, n Name, digest string) string { - var b strings.Builder - if scheme != "" { - b.WriteString(scheme) - b.WriteString("://") - } - b.WriteString(n.String()) - if digest != "" { - b.WriteByte('@') - b.WriteString(digest) - } - return b.String() + return scheme, s, digest } // Merge merges two names into a single name. Non-empty host, namespace, and @@ -141,39 +111,68 @@ func Merge(a, b Name) Name { // IsValid returns true if the name is valid. 
func (n Name) IsValid() bool { - if n.h != "" && !isValidHost(n.h) { + if n.h != "" && !isValidPart(partHost, n.h) { return false } - if n.n != "" && !isValidNamespace(n.n) { + if n.n != "" && !isValidPart(partNamespace, n.n) { return false } - if n.m != "" && !isValidModel(n.m) { + if n.t != "" && !isValidPart(partTag, n.t) { return false } - if n.t != "" && !isValidTag(n.t) { - return false - } - return true + + // at bare minimum, model must be present and valid + return n.m != "" && isValidPart(partModel, n.m) } func (n Name) IsFullyQualified() bool { return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != "" } -func isValidHost(_ string) bool { - return true // TODO: implement +const ( + partHost = iota + partNamespace + partModel + partTag +) + +func isValidPart(kind int, s string) bool { + maxlen := 80 + if kind == partHost { + maxlen = 350 + } + if len(s) > maxlen { + return false + } + + for i := range s { + if i == 0 { + if !isAlphanumericOrUnderscore(s[i]) { + return false + } + continue + } + switch s[i] { + case '_', '-': + case '.': + if kind == partNamespace { + return false + } + case ':': + if kind != partHost { + return false + } + default: + if !isAlphanumericOrUnderscore(s[i]) { + return false + } + } + } + return true } -func isValidNamespace(_ string) bool { - return true // TODO: implement -} - -func isValidModel(_ string) bool { - return true // TODO: implement -} - -func isValidTag(_ string) bool { - return true // TODO: implement +func isAlphanumericOrUnderscore(c byte) bool { + return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_' } func (n Name) Host() string { return n.h } diff --git a/server/internal/internal/names/name_test.go b/server/internal/internal/names/name_test.go index 760fec5fa..e3dc5fe3c 100644 --- a/server/internal/internal/names/name_test.go +++ b/server/internal/internal/names/name_test.go @@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) { } for _, tt := range cases { t.Run(tt.in, func(t *testing.T) { - scheme, name, digest := ParseExtended(tt.in) - if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest { + scheme, name, digest := Split(tt.in) + n := Parse(name) + if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest { t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest) } - - // Round trip - if got := FormatExtended(scheme, name, digest); got != tt.in { - t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got) - } }) } } @@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) { junkName = Parse("h/n/m:t") } } + +const ( + part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888" + part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333" +) + +var testCases = map[string]bool{ // name -> valid + "": false, + + "_why/_the/_lucky:_stiff": true, + + // minimal + "h/n/m:t": true, + + "host/namespace/model:tag": true, + "host/namespace/model": true, + "namespace/model": true, + "model": true, + + // long (but valid) + part80 + "/" + part80 + "/" + part80 + ":" + part80: true, + part350 + "/" + part80 + 
"/" + part80 + ":" + part80: true, + + // too long + part80 + "/" + part80 + "/" + part80 + ":" + part350: false, + "x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false, + + "h/nn/mm:t": true, // bare minimum part sizes + + // unqualified + "m": true, + "n/m:": true, + "h/n/m": true, + "@t": false, + "m@d": false, + + // invalids + "^": false, + "mm:": true, + "/nn/mm": true, + "//": false, // empty model + "//mm": true, + "hh//": false, // empty model + "//mm:@": false, + "00@": false, + "@": false, + + // not starting with alphanum + "-hh/nn/mm:tt": false, + "hh/-nn/mm:tt": false, + "hh/nn/-mm:tt": false, + "hh/nn/mm:-tt": false, + + // smells like a flag + "-h": false, + + // hosts + "host:https/namespace/model:tag": true, + + // colon in non-host part before tag + "host/name:space/model:tag": false, +} + +func TestParseNameValidation(t *testing.T) { + for s, valid := range testCases { + got := Parse(s) + if got.IsValid() != valid { + t.Logf("got: %v", got) + t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid()) + } + } +} diff --git a/server/internal/internal/syncs/line_test.go b/server/internal/internal/syncs/line_test.go index d52160260..94114a565 100644 --- a/server/internal/internal/syncs/line_test.go +++ b/server/internal/internal/syncs/line_test.go @@ -1,3 +1,5 @@ +//go:build goexperiment.synctest + package syncs import ( diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go index 8eb6daf89..62fefb4c7 100644 --- a/server/internal/registry/server.go +++ b/server/internal/registry/server.go @@ -7,9 +7,12 @@ import ( "cmp" "encoding/json" "errors" + "fmt" "io" "log/slog" "net/http" + "sync" + "time" "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" @@ -27,12 +30,15 @@ import ( // directly to the blob disk cache. type Local struct { Client *ollama.Registry // required - Cache *blob.DiskCache // required Logger *slog.Logger // required // Fallback, if set, is used to handle requests that are not handled by // this handler. Fallback http.Handler + + // Prune, if set, is called to prune the local disk cache after a model + // is deleted. 
+ Prune func() error // optional } // serverError is like ollama.Error, but with a Status field for the HTTP @@ -107,6 +113,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) { switch r.URL.Path { case "/api/delete": return false, s.handleDelete(rec, r) + case "/api/pull": + return false, s.handlePull(rec, r) default: if s.Fallback != nil { s.Fallback.ServeHTTP(rec, r) @@ -199,13 +207,107 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { if err != nil { return err } - ok, err := s.Client.Unlink(s.Cache, p.model()) + ok, err := s.Client.Unlink(p.model()) if err != nil { return err } if !ok { - return &serverError{404, "manifest_not_found", "manifest not found"} + return &serverError{404, "not_found", "model not found"} } + if s.Prune == nil { + return nil + } + return s.Prune() +} + +type progressUpdateJSON struct { + Status string `json:"status"` + Digest blob.Digest `json:"digest,omitempty,omitzero"` + Total int64 `json:"total,omitempty,omitzero"` + Completed int64 `json:"completed,omitempty,omitzero"` +} + +func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { + if r.Method != "POST" { + return errMethodNotAllowed + } + + p, err := decodeUserJSON[*params](r.Body) + if err != nil { + return err + } + + maybeFlush := func() { + fl, _ := w.(http.Flusher) + if fl != nil { + fl.Flush() + } + } + defer maybeFlush() + + var mu sync.Mutex + enc := json.NewEncoder(w) + enc.Encode(progressUpdateJSON{Status: "pulling manifest"}) + + ctx := ollama.WithTrace(r.Context(), &ollama.Trace{ + Update: func(l *ollama.Layer, n int64, err error) { + mu.Lock() + defer mu.Unlock() + + // TODO(bmizerany): coalesce these updates; writing per + // update is expensive + enc.Encode(progressUpdateJSON{ + Digest: l.Digest, + Status: "pulling", + Total: l.Size, + Completed: n, + }) + }, + }) + + done := make(chan error, 1) + go func() { + // TODO(bmizerany): continue to support non-streaming responses + done <- s.Client.Pull(ctx, p.model()) + }() + + func() { + t := time.NewTicker(100 * time.Millisecond) + defer t.Stop() + for { + select { + case <-t.C: + mu.Lock() + maybeFlush() + mu.Unlock() + case err := <-done: + if err != nil { + var status string + if errors.Is(err, ollama.ErrModelNotFound) { + status = fmt.Sprintf("error: model %q not found", p.model()) + enc.Encode(progressUpdateJSON{Status: status}) + } else { + status = fmt.Sprintf("error: %v", err) + enc.Encode(progressUpdateJSON{Status: status}) + } + return + } + + // These final updates are not strictly necessary, because they have + // already happened at this point. Our pull handler code used to do + // these steps after, not during, the pull, and they were slow, so we + // wanted to provide feedback to users what was happening. For now, we + // keep them to not jar users who are used to seeing them. We can phase + // them out with a new and nicer UX later. One without progress bars + // and digests that no one cares about. 
+ enc.Encode(progressUpdateJSON{Status: "verifying layers"}) + enc.Encode(progressUpdateJSON{Status: "writing manifest"}) + enc.Encode(progressUpdateJSON{Status: "success"}) + return + } + } + }() + return nil } diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go index 22267ba7d..597e9bd63 100644 --- a/server/internal/registry/server_test.go +++ b/server/internal/registry/server_test.go @@ -1,17 +1,27 @@ package registry import ( + "bytes" + "context" "encoding/json" + "fmt" + "io" + "io/fs" + "net" "net/http" "net/http/httptest" "os" "regexp" "strings" + "sync" "testing" "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/testutil" + "golang.org/x/tools/txtar" + + _ "embed" ) type panicTransport struct{} @@ -30,7 +40,7 @@ type bytesResetter interface { Reset() } -func newTestServer(t *testing.T) *Local { +func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local { t.Helper() dir := t.TempDir() err := os.CopyFS(dir, os.DirFS("testdata/models")) @@ -41,11 +51,26 @@ func newTestServer(t *testing.T) *Local { if err != nil { t.Fatal(err) } - rc := &ollama.Registry{ - HTTPClient: panicOnRoundTrip, + + client := panicOnRoundTrip + if upstreamRegistry != nil { + s := httptest.NewTLSServer(upstreamRegistry) + t.Cleanup(s.Close) + tr := s.Client().Transport.(*http.Transport).Clone() + tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", s.Listener.Addr().String()) + } + client = &http.Client{Transport: tr} } + + rc := &ollama.Registry{ + Cache: c, + HTTPClient: client, + Mask: "example.com/library/_:latest", + } + l := &Local{ - Cache: c, Client: rc, Logger: testutil.Slogger(t), } @@ -85,9 +110,9 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) { func TestServerDelete(t *testing.T) { check := testutil.Checker(t) - s := newTestServer(t) + s := newTestServer(t, nil) - _, err := s.Client.ResolveLocal(s.Cache, "smol") + _, err := s.Client.ResolveLocal("smol") check(err) got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`) @@ -95,7 +120,7 @@ func TestServerDelete(t *testing.T) { t.Fatalf("Code = %d; want 200", got.Code) } - _, err = s.Client.ResolveLocal(s.Cache, "smol") + _, err = s.Client.ResolveLocal("smol") if err == nil { t.Fatal("expected smol to have been deleted") } @@ -109,11 +134,8 @@ func TestServerDelete(t *testing.T) { got = s.send(t, "DELETE", "/api/delete", ``) checkErrorResponse(t, got, 400, "bad_request", "empty request body") - got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`) - checkErrorResponse(t, got, 404, "manifest_not_found", "not found") - got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`) - checkErrorResponse(t, got, 400, "bad_request", "invalid name") + checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name") got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body checkErrorResponse(t, got, 404, "not_found", "not found") @@ -130,8 +152,105 @@ func TestServerDelete(t *testing.T) { } } +//go:embed testdata/registry.txt +var registryTXT []byte + +var registryFS = sync.OnceValue(func() fs.FS { + // Txtar gets hung up on \r\n line endings, so we need to convert them + // to \n when parsing the txtar on Windows. 
+ data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n")) + a := txtar.Parse(data) + fmt.Printf("%q\n", a.Comment) + fsys, err := txtar.FS(a) + if err != nil { + panic(err) + } + return fsys +}) + +func TestServerPull(t *testing.T) { + modelsHandler := http.FileServerFS(registryFS()) + s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v2/library/BOOM/manifests/latest": + w.WriteHeader(999) + io.WriteString(w, `{"error": "boom"}`) + case "/v2/library/unknown/manifests/latest": + w.WriteHeader(404) + io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`) + default: + t.Logf("serving file: %s", r.URL.Path) + modelsHandler.ServeHTTP(w, r) + } + }) + + checkResponse := func(got *httptest.ResponseRecorder, wantlines string) { + t.Helper() + + if got.Code != 200 { + t.Fatalf("Code = %d; want 200", got.Code) + } + gotlines := got.Body.String() + t.Logf("got:\n%s", gotlines) + for want := range strings.Lines(wantlines) { + want = strings.TrimSpace(want) + want, unwanted := strings.CutPrefix(want, "!") + want = strings.TrimSpace(want) + if !unwanted && !strings.Contains(gotlines, want) { + t.Fatalf("! missing %q in body", want) + } + if unwanted && strings.Contains(gotlines, want) { + t.Fatalf("! unexpected %q in body", want) + } + } + } + + got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"} + `) + + got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5} + {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3} + {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5} + {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3} + {"status":"verifying layers"} + {"status":"writing manifest"} + {"status":"success"} + `) + + got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"error: model \"unknown\" not found"} + `) + + got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`) + checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed") + + got = s.send(t, "POST", "/api/pull", `!`) + checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' 
looking for beginning of value") + + got = s.send(t, "POST", "/api/pull", ``) + checkErrorResponse(t, got, 400, "bad_request", "empty request body") + + got = s.send(t, "POST", "/api/pull", `{"model": "://"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"error: invalid or missing name: \"\""} + + !verifying + !writing + !success + `) +} + func TestServerUnknownPath(t *testing.T) { - s := newTestServer(t) + s := newTestServer(t, nil) got := s.send(t, "DELETE", "/api/unknown", `{}`) checkErrorResponse(t, got, 404, "not_found", "not found") } diff --git a/server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest b/server/internal/registry/testdata/models/manifests/example.com/library/smol/latest similarity index 100% rename from server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest rename to server/internal/registry/testdata/models/manifests/example.com/library/smol/latest diff --git a/server/internal/registry/testdata/registry.txt b/server/internal/registry/testdata/registry.txt new file mode 100644 index 000000000..2fc363fcb --- /dev/null +++ b/server/internal/registry/testdata/registry.txt @@ -0,0 +1,22 @@ +-- v2/library/smol/manifests/latest -- +{ + "schemaVersion": 2, + "mediaType": "application/vnd.docker.distribution.manifest.v2+json", + "config": { + "mediaType": "application/vnd.docker.container.image.v1+json", + "digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356", + "size": 3 + }, + "layers": [ + { + "mediaType": "application/vnd.ollama.image.model", + "digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312", + "size": 5 + } + ] +} + +-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 -- +GGUF +-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 -- +{} diff --git a/server/prompt.go b/server/prompt.go index 233dffd69..5b5b958f1 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/model/models/mllama" "github.com/ollama/ollama/template" @@ -93,7 +92,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. 
var imgData llm.ImageData if isMllama { - if envconfig.NewEngine() { + if len(m.ProjectorPaths) == 0 { imgData = llm.ImageData{ ID: len(images), Data: i, diff --git a/server/routes.go b/server/routes.go index ff42000f8..3efa12e43 100644 --- a/server/routes.go +++ b/server/routes.go @@ -34,7 +34,6 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/model/models/mllama" "github.com/ollama/ollama/openai" - "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" @@ -43,6 +42,12 @@ import ( "github.com/ollama/ollama/version" ) +func experimentEnabled(name string) bool { + return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) +} + +var useClient2 = experimentEnabled("client2") + var mode string = gin.DebugMode type Server struct { @@ -206,7 +211,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { images := make([]llm.ImageData, len(req.Images)) for i := range req.Images { - if isMllama && !envconfig.NewEngine() { + if isMllama && len(model.ProjectorPaths) > 0 { data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i])) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"}) @@ -1129,7 +1134,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc { } } -func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) { +func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { corsConfig := cors.DefaultConfig() corsConfig.AllowWildcard = true corsConfig.AllowBrowserExtensions = true @@ -1174,6 +1179,7 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha r.HEAD("/api/tags", s.ListHandler) r.GET("/api/tags", s.ListHandler) r.POST("/api/show", s.ShowHandler) + r.DELETE("/api/delete", s.DeleteHandler) // Create r.POST("/api/create", s.CreateHandler) @@ -1195,15 +1201,19 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler) - // wrap old with new - rs := ®istry.Local{ - Cache: c, - Client: rc, - Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() - Fallback: r, + if rc != nil { + // wrap old with new + rs := ®istry.Local{ + Client: rc, + Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() + Fallback: r, + + Prune: PruneLayers, + } + return rs, nil } - return rs, nil + return r, nil } func Serve(ln net.Listener) error { @@ -1258,19 +1268,20 @@ func Serve(ln net.Listener) error { s := &Server{addr: ln.Addr()} - c, err := ollama.DefaultCache() - if err != nil { - return err + var rc *ollama.Registry + if useClient2 { + var err error + rc, err = ollama.DefaultRegistry() + if err != nil { + return err + } } - rc, err := ollama.DefaultRegistry() + + h, err := s.GenerateRoutes(rc) if err != nil { return err } - h, err := s.GenerateRoutes(c, rc) - if err != nil { - return err - } http.Handle("/", h) ctx, done := context.WithCancel(context.Background()) diff --git a/server/routes_test.go b/server/routes_test.go index 0dd782f4f..e13c4b599 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -23,7 +23,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/openai" - 
"github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) { modelsDir := t.TempDir() t.Setenv("OLLAMA_MODELS", modelsDir) - c, err := blob.Open(modelsDir) - if err != nil { - t.Fatalf("failed to open models dir: %v", err) - } - rc := &ollama.Registry{ // This is a temporary measure to allow us to move forward, // surfacing any code contacting ollama.com we do not intended @@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) { } s := &Server{} - router, err := s.GenerateRoutes(c, rc) + router, err := s.GenerateRoutes(rc) if err != nil { t.Fatalf("failed to generate routes: %v", err) }