Mirror of https://github.com/ollama/ollama.git

Commit a374fbde4d: Merge branch 'ollama:main' into mmap
@@ -51,6 +51,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
 
+add_compile_definitions(NDEBUG)
+
 set(GGML_CPU ON)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
 set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
@@ -406,6 +406,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
 - [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
 - [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
+- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
 
 ### Cloud
 
@@ -449,6 +450,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
 - [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
 - [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
+- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
 
 ### Apple Vision Pro
 
@@ -1236,11 +1236,11 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
 		return err
 	}
 	if err := client.Heartbeat(cmd.Context()); err != nil {
-		if !strings.Contains(err.Error(), " refused") {
+		if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) {
 			return err
 		}
 		if err := startApp(cmd.Context(), client); err != nil {
-			return errors.New("could not connect to ollama app, is it running?")
+			return fmt.Errorf("ollama server not responding - %w", err)
 		}
 	}
 	return nil
@@ -4,17 +4,27 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"log/slog"
 	"os"
 	"os/exec"
+	"path"
 	"path/filepath"
 	"strings"
 	"syscall"
+	"unsafe"
 
 	"github.com/ollama/ollama/api"
+	"golang.org/x/sys/windows"
+)
+
+const (
+	Installer = "OllamaSetup.exe"
 )
 
 func startApp(ctx context.Context, client *api.Client) error {
-	// log.Printf("XXX Attempting to find and start ollama app")
+	if len(isProcRunning(Installer)) > 0 {
+		return fmt.Errorf("upgrade in progress...")
+	}
 	AppName := "ollama app.exe"
 	exe, err := os.Executable()
 	if err != nil {
@@ -56,3 +66,41 @@ func startApp(ctx context.Context, client *api.Client) error {
 	}
 	return waitForServer(ctx, client)
 }
+
+func isProcRunning(procName string) []uint32 {
+	pids := make([]uint32, 2048)
+	var ret uint32
+	if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
+		slog.Debug("failed to check for running installers", "error", err)
+		return nil
+	}
+	pids = pids[:ret]
+	var matches []uint32
+	for _, pid := range pids {
+		if pid == 0 {
+			continue
+		}
+		hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid)
+		if err != nil {
+			continue
+		}
+		defer windows.CloseHandle(hProcess)
+		var module windows.Handle
+		var cbNeeded uint32
+		cb := (uint32)(unsafe.Sizeof(module))
+		if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil {
+			continue
+		}
+		var sz uint32 = 1024 * 8
+		moduleName := make([]uint16, sz)
+		cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0)))
+		if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER {
+			continue
+		}
+		exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName)))
+		if strings.EqualFold(exeFile, procName) {
+			matches = append(matches, pid)
+		}
+	}
+	return matches
+}
@@ -94,7 +94,9 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
 	var out []*ggml.Tensor
 	var text []Tensor
 	for _, t := range ts {
-		if t.Name() == "v.position_embd.gate" {
+		if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") {
+			text = append(text, t)
+		} else if t.Name() == "v.position_embd.gate" {
 			for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
 				tt := t.Clone()
 				tt.SetRepacker(m.repack(name))
@@ -105,23 +107,21 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
 					WriterTo: tt,
 				})
 			}
-		} else if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
-			t.SetRepacker(m.repack(t.Name()))
-			out = append(out, &ggml.Tensor{
-				Name:     t.Name(),
-				Kind:     t.Kind(),
-				Shape:    t.Shape(),
-				WriterTo: t,
-			})
-		} else if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
-			out = append(out, &ggml.Tensor{
-				Name:     t.Name(),
-				Kind:     t.Kind(),
-				Shape:    t.Shape(),
-				WriterTo: t,
-			})
 		} else {
-			text = append(text, t)
+			if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
+				t.SetRepacker(m.repack(t.Name()))
+			} else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
+				t.SetRepacker(m.repack(t.Name()))
+			} else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") {
+				t.SetRepacker(m.repack(t.Name()))
+			}
+
+			out = append(out, &ggml.Tensor{
+				Name:     t.Name(),
+				Kind:     t.Kind(),
+				Shape:    t.Shape(),
+				WriterTo: t,
+			})
 		}
 	}
 
@@ -137,6 +137,24 @@ func (m *mllamaModel) repack(name string) Repacker {
 
 		var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
 
+		if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") {
+			heads := m.VisionModel.AttentionHeads
+			if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
+				return nil, err
+			}
+
+			if err := t.T(0, 2, 1, 3); err != nil {
+				return nil, err
+			}
+
+			if err := t.Reshape(dims...); err != nil {
+				return nil, err
+			}
+
+			if err := t.Transpose(); err != nil {
+				return nil, err
+			}
+		} else {
 			t, err = tensor.Tanh(t)
 			if err != nil {
 				return nil, err
@@ -148,6 +166,7 @@ func (m *mllamaModel) repack(name string) Repacker {
 				return nil, err
 			}
 		}
+		}
 
 		t = tensor.Materialize(t)
 		// flatten tensor so it can be return as a vector
@@ -47,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
 	}
 	t.Cleanup(func() { r.Close() })
 
-	m, _, err := ggml.Decode(r, -1)
+	m, err := ggml.Decode(r, -1)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
 	}
 	defer r.Close()
 
-	m, _, err := ggml.Decode(r, -1)
+	m, err := ggml.Decode(r, -1)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -132,22 +132,12 @@ success
 
 ### Supported Quantizations
 
-- `q4_0`
-- `q4_1`
-- `q5_0`
-- `q5_1`
 - `q8_0`
 
 #### K-means Quantizations
 
-- `q3_K_S`
-- `q3_K_M`
-- `q3_K_L`
 - `q4_K_S`
 - `q4_K_M`
-- `q5_K_S`
-- `q5_K_M`
-- `q6_K`
 
 
 ## Sharing your model on ollama.com
@@ -15,6 +15,7 @@ import (
 type GGML struct {
 	container
 	model
+	Length int64
 }
 
 type model interface {
@@ -386,12 +387,12 @@ func DetectContentType(b []byte) string {
 //
 // It collects array values for arrays with a size less than or equal to
 // maxArraySize. If the maxArraySize is negative, all arrays are collected.
-func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
+func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
 	rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
 
 	var magic uint32
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
-		return nil, 0, err
+		return nil, err
 	}
 
 	var c container
@@ -401,24 +402,25 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 	case FILE_MAGIC_GGUF_BE:
 		c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
 	default:
-		return nil, 0, errors.New("invalid file magic")
+		return nil, errors.New("invalid file magic")
 	}
 
 	model, err := c.Decode(rs)
 	if err != nil {
-		return nil, 0, err
+		return nil, err
 	}
 
 	offset, err := rs.Seek(0, io.SeekCurrent)
 	if err != nil {
-		return nil, 0, err
+		return nil, err
 	}
 
 	// final model type
 	return &GGML{
 		container: c,
 		model:     model,
-	}, offset, nil
+		Length:    offset,
+	}, nil
 }
 
 func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
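For context on the signature change above, here is a minimal, hypothetical caller of the new API (only `ggml.Decode`, `GGML.KV`, and `GGML.Length` come from the change; the helper and its names are illustrative). The byte offset that used to be returned as a second value is now read from the `Length` field:

```go
package example

import (
	"log/slog"
	"os"

	"github.com/ollama/ollama/fs/ggml"
)

// loadMeta decodes GGUF metadata from a model file. A maxArraySize < 0
// collects all array values, matching the documented behavior of Decode.
func loadMeta(path string) (*ggml.GGML, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	g, err := ggml.Decode(f, -1) // previously: g, offset, err := ggml.Decode(f, -1)
	if err != nil {
		return nil, err
	}

	// the offset past the decoded header now lives on the struct
	slog.Info("decoded model", "num_key_values", len(g.KV()), "length", g.Length)
	return g, nil
}
```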
@@ -35,7 +35,7 @@ func TestWriteGGUF(t *testing.T) {
 	}
 	defer r.Close()
 
-	ff, _, err := Decode(r, 0)
+	ff, err := Decode(r, 0)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -19,7 +19,7 @@ func TestVisionModels(t *testing.T) {
 	}
 	testCases := []testCase{
 		{
-			model: "llava:7b",
+			model: "qwen2.5vl",
 		},
 		{
 			model: "llama3.2-vision",
@@ -60,6 +60,7 @@
 }
 
 func TestIntegrationSplitBatch(t *testing.T) {
+	skipUnderMinVRAM(t, 6)
 	image, err := base64.StdEncoding.DecodeString(imageEncoding)
 	require.NoError(t, err)
 	req := api.GenerateRequest{

(File diff suppressed because one or more lines are too long)
@@ -30,6 +30,11 @@ type Causal struct {
 
 	// ** current forward pass **
 
+	// curReserve indicates that this forward pass is only for
+	// memory reservation and we should not update our metadata
+	// based on it.
+	curReserve bool
+
 	// the active layer for Get and Put
 	curLayer int
 
@@ -159,12 +164,13 @@ func (c *Causal) Close() {
 }
 
 func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
+	c.curReserve = reserve
 	c.curBatchSize = len(batch.Positions)
 	c.curSequences = batch.Sequences
 	c.curPositions = batch.Positions
 	c.opts.Except = nil
 
-	if !reserve {
+	if !c.curReserve {
 		c.updateSlidingWindow()
 
 		var err error
@@ -211,10 +217,9 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
 		c.curCellRange.max = len(c.cells) - 1
 	}
 
-	var err error
-	c.curMask, err = c.buildMask(ctx)
+	c.curMask = c.buildMask(ctx)
 
-	return err
+	return nil
 }
 
 func newRange() cellRange {
@@ -297,7 +302,7 @@ func roundUp(length, pad int) int {
 // Builds a mask of history x batch indicating whether for each token in the batch the
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
-func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
+func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 	// Align and pad the two dimensions as required by the backend
 	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
 
@@ -305,6 +310,11 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
 
 	length := c.curCellRange.max - c.curCellRange.min + 1
+
+	if c.curReserve {
+		return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
+	}
+
 	mask := make([]float32, batchSize*length)
 
 	for i := range c.curBatchSize {
@@ -325,10 +335,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 			mask[i] = float32(math.Inf(-1))
 		}
 
-	maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
-	if err != nil {
-		return nil, err
-	}
+	maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
 
 	if c.config.MaskDType != ml.DTypeF32 {
 		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
@@ -336,7 +343,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 		maskTensor = out
 	}
 
-	return maskTensor, nil
+	return maskTensor
 }
 
 func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
@@ -491,12 +498,7 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
 	if !slices.Equal(c.opts.Except, opts.Except) {
 		c.opts = opts
 		if ctx != nil {
-			var err error
-			c.curMask, err = c.buildMask(ctx)
-			if err != nil {
-				// This error should never occur because we have previously built a mask with the same shape
-				panic(fmt.Errorf("SetCausal: %w", err))
-			}
+			c.curMask = c.buildMask(ctx)
 		}
 	}
 }
@@ -652,10 +654,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 		}
 	}
 
-	kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
-	if err != nil {
-		return err
-	}
+	kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
 
 	for i, key := range c.keys {
 		if key == nil {
@@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 		}
 
 		cache.SetLayer(0)
-		tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
+		tensor := context.FromFloatSlice(test.in, test.inShape...)
 		cache.Put(context, tensor, tensor)
 
 		out, _, mask := cache.Get(context)
@@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) {
 	}
 
 	cache.SetLayer(0)
-	tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
+	tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
 	cache.Put(context, tensor, tensor)
 
 	// with window size 4, nothing has slid out of the window yet
@@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) {
 	}
 
 	cache.SetLayer(0)
-	tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
+	tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
 	cache.Put(context, tensor, tensor)
 
 	// only the latest position has overlapping windows
@@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 	return c.Empty(dtype, shape...)
 }
 
-func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
+func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
 	t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
 
 	copy(t.data, s)
 
-	return t, nil
+	return t
 }
 
-func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
+func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor {
 	f := make([]float32, len(s))
 	for i := range f {
 		f[i] = float32(s[i])
 	}
 
-	out, _ := c.FromFloatSlice(f, shape...)
+	out := c.FromFloatSlice(f, shape...)
 	out.(*testTensor).dtype = ml.DTypeI32
 
-	return out, nil
+	return out
 }
 
 func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
@@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso
 		s = append(s, i)
 	}
 
-	out, _ := c.FromFloatSlice(s, len(s))
+	out := c.FromFloatSlice(s, len(s))
 	out.(*testTensor).dtype = dtype
 	return out
 }
@@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
 
 func (c *testContext) Compute(...ml.Tensor) {}
 
-func (c *testContext) Reserve() error { return nil }
+func (c *testContext) Reserve() {}
 
 func (c *testContext) MaxGraphNodes() int {
 	return 10
@@ -544,7 +544,7 @@ func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext,
 	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
 	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
 	cparams.penalty_freq = C.float(params.PenaltyFreq)
-	cparams.penalty_present = C.float(params.PenaltyFreq)
+	cparams.penalty_present = C.float(params.PenaltyPresent)
 	cparams.seed = C.uint32_t(params.Seed)
 
 	grammar := C.CString(params.Grammar)
@@ -580,7 +580,7 @@ func SchemaToGrammar(schema []byte) []byte {
 	defer C.free(unsafe.Pointer(cStr))
 
 	// Allocate buffer for grammar based on schema length but with upper bound
-	maxLen := min(1024*1024, len(schema)*4)
+	maxLen := max(32768, min(1024*1024, len(schema)*4))
 	buf := make([]byte, maxLen)
 
 	// Call C function to convert schema to grammar
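As a quick check of the new bound (a standalone sketch, not part of the change itself): the grammar buffer is now never smaller than 32 KiB, is still capped at 1 MiB, and scales at four bytes per schema byte in between.

```go
package main

import "fmt"

// grammarBufLen mirrors the updated sizing expression used for the grammar buffer.
func grammarBufLen(schemaLen int) int {
	return max(32768, min(1024*1024, schemaLen*4))
}

func main() {
	fmt.Println(grammarBufLen(1_000))     // 32768: small schemas get the new 32 KiB floor
	fmt.Println(grammarBufLen(100_000))   // 400000: mid-size schemas scale at 4 bytes per byte
	fmt.Println(grammarBufLen(1_000_000)) // 1048576: still capped at 1 MiB
}
```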
@@ -0,0 +1,156 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Jesse Gross <jesse@ollama.com>
+Date: Fri, 18 Apr 2025 15:58:19 -0700
+Subject: [PATCH] graph memory reporting on failure
+
+---
+ ggml/include/ggml-alloc.h   |  6 ++++++
+ ggml/include/ggml-backend.h |  6 ++++++
+ ggml/src/ggml-alloc.c       | 38 +++++++++++++++++++++++++++++++++----
+ ggml/src/ggml-backend.cpp   | 10 ++++++++++
+ 4 files changed, 56 insertions(+), 4 deletions(-)
+
+diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h
+index 2cb150fd..781b1e10 100644
+--- a/ggml/include/ggml-alloc.h
++++ b/ggml/include/ggml-alloc.h
+@@ -66,6 +66,12 @@ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph
+
+ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
+
++struct ggml_allocr_buffer_status {
++    size_t size;
++    bool allocated;
++};
++GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id);
++
+ // Utils
+ // Create a buffer and allocate all the tensors in a ggml_context
+ GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
+index 778927f6..74e46716 100644
+--- a/ggml/include/ggml-backend.h
++++ b/ggml/include/ggml-backend.h
+@@ -304,6 +304,12 @@ extern "C" {
+
+     GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
+
++    struct ggml_backend_buffer_status {
++        size_t size;
++        bool allocated;
++    };
++    GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
++
+     GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+     GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
+
+diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
+index 5fd379f6..04812990 100644
+--- a/ggml/src/ggml-alloc.c
++++ b/ggml/src/ggml-alloc.c
+@@ -364,6 +364,7 @@ struct node_alloc {
+ struct ggml_gallocr {
+     ggml_backend_buffer_type_t * bufts; // [n_buffers]
+     ggml_backend_buffer_t * buffers; // [n_buffers]
++    size_t *buffer_sizes; // [n_buffers]
+     struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
+     int n_buffers;
+
+@@ -387,6 +388,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs
+     galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t));
+     GGML_ASSERT(galloc->buffers != NULL);
+
++    galloc->buffer_sizes = calloc(n_bufs, sizeof(size_t));
++    GGML_ASSERT(galloc->buffer_sizes != NULL);
++
+     galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *));
+     GGML_ASSERT(galloc->buf_tallocs != NULL);
+
+@@ -453,6 +457,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) {
+     ggml_hash_set_free(&galloc->hash_set);
+     free(galloc->hash_values);
+     free(galloc->bufts);
++    free(galloc->buffer_sizes);
+     free(galloc->buffers);
+     free(galloc->buf_tallocs);
+     free(galloc->node_allocs);
+@@ -748,6 +753,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
+         }
+     }
+
++    bool success = true;
++
+     // reallocate buffers if needed
+     for (int i = 0; i < galloc->n_buffers; i++) {
+         // if the buffer type is used multiple times, we reuse the same buffer
+@@ -769,15 +776,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
+
+             ggml_backend_buffer_free(galloc->buffers[i]);
+             galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
+-            if (galloc->buffers[i] == NULL) {
++            if (galloc->buffers[i]) {
++                galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]);
++                ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
++            } else {
+                 GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
+-                return false;
++                galloc->buffer_sizes[i] = new_size;
++                success = false;
+             }
+-            ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
++        } else {
++            galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]);
+         }
+     }
+
+-    return true;
++    return success;
+ }
+
+ bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
+@@ -934,6 +946,24 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
+     return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]);
+ }
+
++struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
++    GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers);
++
++    for (int i = 0; i < buffer_id; i++) {
++        if (galloc->buf_tallocs[i] == galloc->buf_tallocs[buffer_id]) {
++            // This buffer is the same as a previous one due to the same buffer type being used multiple times
++            // (See above.) However, we need a different check because multiple buffers might be NULL in our
++            // case and we still want to know the attempted size.
++
++            struct ggml_allocr_buffer_status status = {0, true};
++            return status;
++        }
++    }
++
++    struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL};
++    return status;
++}
++
+ // utils
+
+ static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
+diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
+index 0ce73a99..be335e8c 100644
+--- a/ggml/src/ggml-backend.cpp
++++ b/ggml/src/ggml-backend.cpp
+@@ -1629,6 +1629,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
+     return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
+ }
+
++struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
++    int backend_index = ggml_backend_sched_backend_id(sched, backend);
++    GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
++
++    struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index);
++    struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated};
++
++    return status;
++}
++
+ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+     int backend_index = ggml_backend_sched_backend_id(sched, backend);
+     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
@@ -1,12 +1,9 @@
 package llm
 
 import (
-	"cmp"
 	"fmt"
 	"log/slog"
-	"maps"
 	"os"
-	"slices"
 	"strconv"
 	"strings"
 
@@ -85,8 +82,11 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 	var graphOffload uint64
 
 	// Projectors loaded into GPU0 only
-	var projectorWeights uint64
-	var projectorGraph uint64
+	var llamaEngineProjectorWeights uint64
+
+	// Projectors loaded with output layer
+	var ollamaEngineProjectorWeights uint64
+	var ollamaEngineProjectorGraph uint64
 
 	// Conditional output size on GPU 0
 	var memoryLayerOutput uint64
@@ -111,21 +111,23 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", availableList)
 
 	for _, projector := range projectors {
-		weight := projectorMemoryRequirements(projector)
-		projectorWeights += weight
+		llamaEngineProjectorWeights += projectorMemoryRequirements(projector)
 
 		// multimodal models require at least 2048 context
 		opts.NumCtx = max(opts.NumCtx, 2048)
 	}
-	if projectorWeights == 0 && projectorGraph == 0 {
-		projectorWeights, projectorGraph = f.VisionGraphSize()
+	if llamaEngineProjectorWeights == 0 {
+		ollamaEngineProjectorWeights, ollamaEngineProjectorGraph = f.VisionGraphSize()
+		opts.NumCtx = max(opts.NumCtx, 2048)
 	}
 
 	layers := f.Tensors().GroupLayers()
-	// add one layer (chosing the max layer) worth of memory as a buffer
-	layerSize = slices.MaxFunc(slices.Collect(maps.Values(layers)), func(a, b ggml.Layer) int {
-		return cmp.Compare(a.Size(), b.Size())
-	}).Size()
+	// add one layer worth of memory as a buffer
+	if blk0, ok := layers["blk.0"]; ok {
+		layerSize = blk0.Size()
+	} else {
+		slog.Warn("model missing blk.0 layer size")
+	}
 
 	var kvct string
 	if envconfig.FlashAttention() &&
@@ -163,6 +165,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		graphFullOffload = graphPartialOffload
 	}
 
+	// Output layer handled at the end if we have space
 	if layer, ok := layers["output_norm"]; ok {
 		memoryLayerOutput += layer.Size()
 	}
@@ -172,8 +175,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		memoryLayerOutput += layer.Size()
 	}
 
-	// Output layer handled at the end if we have space
-	gpuZeroOverhead := projectorWeights + projectorGraph
+	gpuZeroOverhead := llamaEngineProjectorWeights
 
 	// Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
 	var layerCount int
@@ -216,6 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 	if len(gpusWithSpace) > 0 {
 		gpuZeroID = gpusWithSpace[0].i
 		gpuAllocations[gpuZeroID] += gpuZeroOverhead
+	} else {
+		overflow += gpuZeroOverhead
 	}
 
 	// For all the layers, find where they can fit on the GPU(s)
@@ -256,21 +260,24 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 	}
 
 	// Determine if we need to consider output then find where it fits
-	if memoryLayerOutput > 0 && (opts.NumGPU < 0 || layerCount < opts.NumGPU) {
+	memoryLastLayer := memoryLayerOutput + ollamaEngineProjectorWeights + ollamaEngineProjectorGraph
+	if memoryLastLayer > 0 {
+		if opts.NumGPU < 0 || layerCount < opts.NumGPU {
 			for j := len(gpusWithSpace); j > 0; j-- {
 				g := gpusWithSpace[layerCount%j]
 				used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
-				if g.g.FreeMemory > overhead+used+memoryLayerOutput {
-					gpuAllocations[g.i] += memoryLayerOutput
+				if g.g.FreeMemory > overhead+used+memoryLastLayer {
+					gpuAllocations[g.i] += memoryLastLayer
 					layerCounts[g.i]++
 					layerCount++
 					break
 				}
 			}
+		}
 
 		if layerCount < int(f.KV().BlockCount())+1 {
 			fullyLoaded = false
-			overflow += memoryLayerOutput
+			overflow += memoryLastLayer
 		}
 	}
 
@@ -328,8 +335,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		memoryLayerOutput:   memoryLayerOutput,
 		graphFullOffload:    graphFullOffload,
 		graphPartialOffload: graphPartialOffload,
-		projectorWeights:    projectorWeights,
-		projectorGraph:      projectorGraph,
+		projectorWeights:    llamaEngineProjectorWeights + ollamaEngineProjectorWeights,
+		projectorGraph:      ollamaEngineProjectorGraph,
 	}
 
 	if gpus[0].Library == "cpu" {
@@ -415,7 +422,7 @@ func projectorMemoryRequirements(filename string) (weights uint64) {
 	}
 	defer file.Close()
 
-	ggml, _, err := ggml.Decode(file, 1024)
+	ggml, err := ggml.Decode(file, 1024)
 	if err != nil {
 		return 0
 	}
@@ -121,7 +121,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
 	}
 	defer f.Close()
 
-	ggml, _, err := ggml.Decode(f, maxArraySize)
+	ggml, err := ggml.Decode(f, maxArraySize)
 	return ggml, err
 }
ml/backend.go (157 changed lines)
@@ -5,8 +5,8 @@ import (
 	"context"
 	"encoding/binary"
 	"fmt"
+	"log/slog"
 	"math"
-	"os"
 	"slices"
 	"strconv"
 	"strings"
@@ -15,6 +15,11 @@ import (
 )
 
 type Backend interface {
+	Load(ctx context.Context, progress func(float32)) error
+
+	// BackendMemory returns the memory allocations that were made for this model
+	BackendMemory() BackendMemory
+
 	Config() fs.Config
 	Get(name string) Tensor
 	NewContext() Context
@@ -52,10 +57,6 @@ type CacheConfig struct {
 
 // BackendParams controls how the backend loads and executes models
 type BackendParams struct {
-	// Progress is a callback function that allows reporting percentage completion
-	// of model loading
-	Progress func(float32)
-
 	// NumThreads sets the number of threads to use if running on the CPU
 	NumThreads int
 
@@ -72,9 +73,122 @@ type BackendParams struct {
 	FlashAttention bool
 }
 
-var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
+// ErrNoMem is returned when panicing due to insufficient memory. It includes
+// the attempted memory allocation.
+type ErrNoMem struct {
+	BackendMemory
+}
+
+func (e ErrNoMem) Error() string {
+	return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
+}
+
+type AllocationStatus int
+
+const (
+	// Unallocated memory - have not yet attempted to allocate
+	Unallocated AllocationStatus = iota
+
+	// Failed memory - tried to allocate the memory and did not succeed
+	Failed
+
+	// Allocated memory = tried and succeeded to allocate memory
+	Allocated
+)
+
+// Memory is the size of an allocation and whether it was successful.
+type Memory struct {
+	Size   uint64
+	Status AllocationStatus
+}
+
+func (m Memory) String() string {
+	s := fmt.Sprint(m.Size)
+
+	switch m.Status {
+	case Unallocated:
+		s += "U"
+	case Failed:
+		s += "F"
+	case Allocated:
+		s += "A"
+	}
+
+	return s
+}
+
+// DeviceMemory provides a breakdown of the memory needed
+// per device, such as a CPU or GPU.
+type DeviceMemory struct {
+	// Name is the name of the device as labeled by the backend. It
+	// may not be persistent across instances of the runner.
+	Name string
+
+	// Weights is the per-layer memory needed for the model weights.
+	Weights []Memory
+
+	// Cache is the per-layer memory needed for the KV cache.
+	Cache []Memory
+
+	// Graph is the size of the compute graph. It is not per-layer.
+	Graph Memory
+}
+
+func memoryPresent(mem []Memory) bool {
+	return slices.ContainsFunc(mem, func(m Memory) bool { return m.Size != 0 })
+}
+
+func (m DeviceMemory) LogValue() slog.Value {
+	var attrs []slog.Attr
+	if memoryPresent(m.Weights) {
+		attrs = append(attrs, slog.Any("Weights", m.Weights))
+	}
+
+	if memoryPresent(m.Cache) {
+		attrs = append(attrs, slog.Any("Cache", m.Cache))
+	}
+
+	if m.Graph.Size != 0 {
+		attrs = append(attrs, slog.Any("Graph", m.Graph))
+	}
+
+	return slog.GroupValue(attrs...)
+}
+
+// BackendMemory provides the amount of memory required to load the model
+// per device based on the BackendParams. In some cases, not all required
+// allocations will be known at this point. However, the size of the most recent
+// allocation is guaranteed to be provided so that if it failed, the caller can
+// accommodate that to make forward progress.
+type BackendMemory struct {
+	// InputsWeights are always located on the CPU and cannot be moved
+	InputWeights Memory
+
+	// CPU model components are located in system memory. This does not
+	// include unified memory allocated through the GPU.
+	CPU DeviceMemory
+
+	// GPU model components are located on one or more GPUs.
+	GPUs []DeviceMemory
+}
+
+func (m BackendMemory) LogValue() slog.Value {
+	var attrs []slog.Attr
+	if m.InputWeights.Size != 0 {
+		attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
+	}
+
+	attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
+	for _, g := range m.GPUs {
+		attrs = append(attrs, slog.Any(g.Name, g))
+	}
+
+	return slog.GroupValue(attrs...)
+}
+
+var backends = make(map[string]func(string, BackendParams) (Backend, error))
 
-func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
+func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
 	if _, ok := backends[name]; ok {
 		panic("backend: backend already registered")
 	}
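A hypothetical consumer of the new types (this sketch assumes the error reaches the caller, for example after a recover around the panicking allocation path, which this hunk does not show). It illustrates how `ErrNoMem` carries the attempted per-device allocations so they can be logged or used to retry with a smaller layout:

```go
package example

import (
	"errors"
	"log/slog"

	"github.com/ollama/ollama/ml"
)

// loadOrReport wraps backend creation and reports the per-device memory
// breakdown when the load fails for lack of memory.
func loadOrReport(modelPath string, params ml.BackendParams) (ml.Backend, error) {
	b, err := ml.NewBackend(modelPath, params)
	if err != nil {
		var noMem ml.ErrNoMem
		if errors.As(err, &noMem) {
			// BackendMemory implements slog.LogValuer, so this logs the CPU
			// and per-GPU weight, cache, and graph allocations.
			slog.Warn("insufficient memory to load model", "memory", noMem.BackendMemory)
		}
		return nil, err
	}
	return b, nil
}
```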
@@ -82,9 +196,9 @@ func RegisterBackend(name string, f func(context.Context, *os.File, BackendParam
 	backends[name] = f
 }
 
-func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
+func NewBackend(modelPath string, params BackendParams) (Backend, error) {
 	if backend, ok := backends["ggml"]; ok {
-		return backend(ctx, f, params)
+		return backend(modelPath, params)
 	}
 
 	return nil, fmt.Errorf("unsupported backend")
@@ -93,8 +207,8 @@ func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend,
 type Context interface {
 	Empty(dtype DType, shape ...int) Tensor
 	Zeros(dtype DType, shape ...int) Tensor
-	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
-	FromIntSlice(s []int32, shape ...int) (Tensor, error)
+	FromFloatSlice(s []float32, shape ...int) Tensor
+	FromIntSlice(s []int32, shape ...int) Tensor
 
 	// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
 	Arange(start, stop, step float32, dtype DType) Tensor
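A small, hypothetical helper illustrating the updated constructors: since `FromFloatSlice` and `FromIntSlice` no longer return an error, call sites simply drop the second value.

```go
package example

import "github.com/ollama/ollama/ml"

// buildInputs creates a 2x2 float tensor and a length-2 index tensor on the
// given context using the new single-value signatures.
func buildInputs(ctx ml.Context) (values, ids ml.Tensor) {
	values = ctx.FromFloatSlice([]float32{1, 2, 3, 4}, 2, 2)
	ids = ctx.FromIntSlice([]int32{0, 1}, 2)
	return values, ids
}
```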
@@ -106,7 +220,7 @@ type Context interface {
 	// graph, simply preallocates memory. Typically called with a
 	// worst case graph to ensure all resources are available for
 	// for future inference.
-	Reserve() error
+	Reserve()
 
 	MaxGraphNodes() int
 	Close()
@@ -119,21 +233,6 @@ type Context interface {
 	Layer(int) Context
 }
 
-// RopeOptions contains optional parameters for RoPE function
-type RopeOptions struct {
-	OriginalContextLen uint32
-}
-
-// RopeOption defines a function that modifies RopeOpts
-type RopeOption func(*RopeOptions)
-
-// WithContextLen sets a custom context length
-func WithContextLen(len uint32) RopeOption {
-	return func(opts *RopeOptions) {
-		opts.OriginalContextLen = len
-	}
-}
-
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
@@ -147,6 +246,8 @@ type Tensor interface {
 	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
+	Div(ctx Context, t2 Tensor) Tensor
+
 	Mulmat(ctx Context, t2 Tensor) Tensor
 	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
 	MulmatID(ctx Context, t2, ids Tensor) Tensor
@@ -155,11 +256,11 @@ type Tensor interface {
 	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
 	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
 	Scale(ctx Context, s float64) Tensor
+	SumRows(ctx Context) Tensor
 
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
 	Sin(ctx Context) Tensor
@@ -10,7 +10,6 @@ import "C"
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"io"
 	"log/slog"
@@ -30,6 +29,7 @@ import (
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
 	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"golang.org/x/sync/errgroup"
 )
 
|
@ -44,8 +44,15 @@ func devices() []*C.struct_ggml_backend_device {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Backend struct {
|
type Backend struct {
|
||||||
|
// modelPath is the location of the model data
|
||||||
|
modelPath string
|
||||||
|
|
||||||
meta *fsggml.GGML
|
meta *fsggml.GGML
|
||||||
|
|
||||||
|
// tensorLoadTargets maps from the name of the tensor in the file
|
||||||
|
// to the name that is used by the model definition
|
||||||
|
tensorLoadTargets map[string][]string
|
||||||
|
|
||||||
sched *C.struct_ggml_backend_sched
|
sched *C.struct_ggml_backend_sched
|
||||||
schedBackends []*C.struct_ggml_backend
|
schedBackends []*C.struct_ggml_backend
|
||||||
schedBufts []*C.struct_ggml_backend_buffer_type
|
schedBufts []*C.struct_ggml_backend_buffer_type
|
||||||
|
@@ -58,14 +65,26 @@ type Backend struct {
 	// layers is the backend used for repeating layers
 	layers map[int]*C.struct_ggml_backend_buffer_type
 
+	// requiredMemory is the cumulative memory allocations needed by the backend
+	requiredMemory *ml.BackendMemory
+
+	// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
+	btDeviceMemory map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory
+
 	flashAttention bool
 
 	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
 	maxGraphNodes int
 }
 
-func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
-	meta, n, err := fsggml.Decode(r, -1)
+func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
+	r, err := os.Open(modelPath)
+	if err != nil {
+		return nil, err
+	}
+	defer r.Close()
+
+	meta, err := fsggml.Decode(r, -1)
 	if err != nil {
 		return nil, err
 	}
@@ -80,6 +99,9 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
 		"num_key_values", len(meta.KV()),
 	)
 
+	var requiredMemory ml.BackendMemory
+	btDeviceMemory := make(map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory)
+
 	type deviceBufferType struct {
 		d   *C.struct_ggml_backend_device
 		bts []*C.struct_ggml_backend_buffer_type
@@ -100,6 +122,8 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
 		}
 	}
 
+	blocks := int(meta.KV().BlockCount())
+
 	// create list of buffer types for the cpu
 	cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
 	for _, d := range append(accels, append(gpus, cpus...)...) {
@ -107,17 +131,27 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||||
case C.GGML_BACKEND_DEVICE_TYPE_CPU,
|
case C.GGML_BACKEND_DEVICE_TYPE_CPU,
|
||||||
C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
|
C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
|
||||||
cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
|
cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
|
||||||
|
btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
|
||||||
|
requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
|
||||||
|
requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)
|
||||||
|
|
||||||
// create list of buffer types for each gpu
|
// create list of buffer types for each gpu
|
||||||
var gpuDeviceBufferTypes []deviceBufferType
|
var gpuDeviceBufferTypes []deviceBufferType
|
||||||
for _, d := range gpus {
|
requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
|
||||||
|
for i, d := range gpus {
|
||||||
bt := C.ggml_backend_dev_buffer_type(d)
|
bt := C.ggml_backend_dev_buffer_type(d)
|
||||||
gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
|
gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
|
||||||
d: d,
|
d: d,
|
||||||
bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
|
bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
|
||||||
})
|
})
|
||||||
|
btDeviceMemory[bt] = &requiredMemory.GPUs[i]
|
||||||
|
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
|
||||||
|
requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
|
||||||
|
requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
useDefaultSplit := true
|
useDefaultSplit := true
|
||||||
|
@ -156,8 +190,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||||
// inputs always use cpu
|
// inputs always use cpu
|
||||||
input := cpuDeviceBufferType
|
input := cpuDeviceBufferType
|
||||||
|
|
||||||
blocks := int(meta.KV().BlockCount())
|
|
||||||
|
|
||||||
// define a range of gpu layers. anything outside of this range is assigned to the cpu
|
// define a range of gpu layers. anything outside of this range is assigned to the cpu
|
||||||
gpuRangeStart := max(0, blocks-params.NumGPULayers)
|
gpuRangeStart := max(0, blocks-params.NumGPULayers)
|
||||||
gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
|
gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
|
||||||
|
@ -198,7 +230,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||||
|
|
||||||
// contexts are shared by tensors of the same buffer type
|
// contexts are shared by tensors of the same buffer type
|
||||||
ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
|
ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
|
||||||
createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
|
createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type, layer int) *C.struct_ggml_tensor {
|
||||||
for _, bt := range bts {
|
for _, bt := range bts {
|
||||||
if _, ok := ctxs[bt]; !ok {
|
if _, ok := ctxs[bt]; !ok {
|
||||||
ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
|
ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
|
||||||
|
@ -224,6 +256,16 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||||
C.ggml_set_name(tt, cname)
|
C.ggml_set_name(tt, cname)
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||||
|
|
||||||
|
size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
|
||||||
|
if layer == -1 {
|
||||||
|
// Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case
|
||||||
|
requiredMemory.InputWeights.Status = ml.Allocated
|
||||||
|
requiredMemory.InputWeights.Size += uint64(size)
|
||||||
|
} else {
|
||||||
|
btDeviceMemory[bt].Weights[layer].Size += uint64(size)
|
||||||
|
}
|
||||||
|
|
||||||
//nolint:staticcheck // TODO: check if buffer type supports this tensor
|
//nolint:staticcheck // TODO: check if buffer type supports this tensor
|
||||||
return tt
|
return tt
|
||||||
}
|
}
|
||||||
|
@ -245,22 +287,22 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||||
for _, t := range meta.Tensors().Items() {
|
for _, t := range meta.Tensors().Items() {
|
||||||
switch {
|
switch {
|
||||||
case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
|
case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
|
||||||
createTensor(tensor{source: t}, input.bts)
|
createTensor(tensor{source: t}, input.bts, -1)
|
||||||
if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
|
if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
|
||||||
createTensor(tensor{source: t, target: "output.weight"}, output.bts)
|
createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
|
||||||
}
|
}
|
||||||
case contains(t.Name, "cls", "output", "output_norm"):
|
case contains(t.Name, "cls", "output", "output_norm"):
|
||||||
createTensor(tensor{source: t}, output.bts)
|
createTensor(tensor{source: t}, output.bts, blocks)
|
||||||
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
|
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
|
||||||
// TODO: assign vision tensors to the gpu if possible
|
// TODO: assign vision tensors to the gpu if possible
|
||||||
createTensor(tensor{source: t}, output.bts)
|
createTensor(tensor{source: t}, output.bts, blocks)
|
||||||
case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
|
case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
|
||||||
// these tensors should be repeated per layer
|
// these tensors should be repeated per layer
|
||||||
for i, layer := range layers {
|
for i, layer := range layers {
|
||||||
createTensor(tensor{
|
createTensor(tensor{
|
||||||
source: t,
|
source: t,
|
||||||
target: "blk." + strconv.Itoa(i) + "." + t.Name,
|
target: "blk." + strconv.Itoa(i) + "." + t.Name,
|
||||||
}, layer.bts)
|
}, layer.bts, i)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
layerIndex := -1
|
layerIndex := -1
|
||||||
|
@ -271,10 +313,10 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||||
}
|
}
|
||||||
|
|
||||||
if layerIndex >= 0 {
|
if layerIndex >= 0 {
|
||||||
createTensor(tensor{source: t}, layers[layerIndex].bts)
|
createTensor(tensor{source: t}, layers[layerIndex].bts, layerIndex)
|
||||||
} else {
|
} else {
|
||||||
// load all other tensors on the cpu
|
// load all other tensors on the cpu
|
||||||
createTensor(tensor{source: t}, input.bts)
|
createTensor(tensor{source: t}, input.bts, -1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -287,8 +329,18 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||||
}
|
}
|
||||||
|
|
||||||
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
|
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
|
||||||
|
for i := range btDeviceMemory[bt].Weights {
|
||||||
|
if btDeviceMemory[bt].Weights[i].Size != 0 {
|
||||||
|
if b != nil {
|
||||||
|
btDeviceMemory[bt].Weights[i].Status = ml.Allocated
|
||||||
|
} else {
|
||||||
|
btDeviceMemory[bt].Weights[i].Status = ml.Failed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if b == nil {
|
if b == nil {
|
||||||
return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
|
panic(ml.ErrNoMem{BackendMemory: requiredMemory})
|
||||||
}
|
}
|
||||||
|
|
||||||
C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
|
C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
|
||||||
|
@ -307,73 +359,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var doneBytes atomic.Uint64
|
|
||||||
totalBytes := uint64(n) - meta.Tensors().Offset
|
|
||||||
|
|
||||||
g, ctx := errgroup.WithContext(ctx)
|
|
||||||
g.SetLimit(runtime.GOMAXPROCS(0))
|
|
||||||
for _, t := range meta.Tensors().Items() {
|
|
||||||
t := t
|
|
||||||
g.Go(func() error {
|
|
||||||
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
|
|
||||||
for i := range tts {
|
|
||||||
target := targets[t.Name][i]
|
|
||||||
if target == "" {
|
|
||||||
target = t.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
tt, ok := tensors[target]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
tts[i] = tt
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new FD for each goroutine so that each FD is read sequentially, rather than
|
|
||||||
// seeking around within an FD shared between all goroutines.
|
|
||||||
file, err := os.Open(r.Name())
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("file open error", "file", r.Name(), "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
|
|
||||||
bts := make([]byte, 128*format.KibiByte)
|
|
||||||
|
|
||||||
var s uint64
|
|
||||||
for s < t.Size() {
|
|
||||||
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
|
|
||||||
if err := ctx.Err(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("file read error", "file", r.Name(), "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tts {
|
|
||||||
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
s += uint64(n)
|
|
||||||
|
|
||||||
if params.Progress != nil {
|
|
||||||
done := doneBytes.Add(uint64(n))
|
|
||||||
params.Progress(float32(done) / float32(totalBytes))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := g.Wait(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// map devices to backend buffer types so new tensors can be assigned to the correct device
deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)

@@ -397,8 +382,10 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,

maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
return &Backend{
+ modelPath: modelPath,
flashAttention: params.FlashAttention,
meta: meta,
+ tensorLoadTargets: targets,
tensors: tensors,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),

@@ -418,6 +405,8 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
}
return m
}(),
+ requiredMemory: &requiredMemory,
+ btDeviceMemory: btDeviceMemory,
maxGraphNodes: maxGraphNodes,
}, nil
}

@@ -426,6 +415,81 @@ func init() {
ml.RegisterBackend("ggml", New)
}

+ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
+ var doneBytes atomic.Uint64
+ totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset

+ g, ctx := errgroup.WithContext(ctx)
+ g.SetLimit(runtime.GOMAXPROCS(0))
+ for _, t := range b.meta.Tensors().Items() {
+ t := t
+ g.Go(func() error {
+ tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name])))
+ for i := range tts {
+ target := b.tensorLoadTargets[t.Name][i]
+ if target == "" {
+ target = t.Name
+ }

+ tt, ok := b.tensors[target]
+ if !ok {
+ return fmt.Errorf("unassigned tensor: %s", t.Name)
+ }

+ tts[i] = tt
+ }

+ // Create a new FD for each goroutine so that each FD is read sequentially, rather than
+ // seeking around within an FD shared between all goroutines.
+ file, err := os.Open(b.modelPath)
+ if err != nil {
+ slog.Warn("file open error", "file", b.modelPath, "error", err)
+ return err
+ }
+ defer file.Close()
+ sr := io.NewSectionReader(file, int64(b.meta.Tensors().Offset+t.Offset), int64(t.Size()))
+ bts := make([]byte, 128*format.KibiByte)

+ var s uint64
+ for s < t.Size() {
+ // Stop if either the parent context has been canceled or if any of the other tensors returned an error
+ if err := ctx.Err(); err != nil {
+ return err
+ }

+ n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
+ if err != nil {
+ slog.Warn("file read error", "file", b.modelPath, "error", err)
+ return err
+ }

+ for _, tt := range tts {
+ C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
+ }

+ s += uint64(n)

+ if progress != nil {
+ done := doneBytes.Add(uint64(n))
+ progress(float32(done) / float32(totalBytes))
+ }
+ }

+ return nil
+ })
+ }

+ if err := g.Wait(); err != nil {
+ return err
+ }

+ return nil
+ }

+ func (b *Backend) BackendMemory() ml.BackendMemory {
+ return *b.requiredMemory
+ }

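Backend construction is now split in two: New opens the file, decodes the metadata and sizes/allocates the weight buffers (recording the plan in BackendMemory), while Load streams the tensor data separately with a progress callback. A minimal caller-side sketch of the two-phase flow, assuming the ml.Backend interface exposes the Load and BackendMemory methods added above (the interface change itself is not part of this hunk), with loadModel and the NumGPULayers value purely illustrative:

    func loadModel(ctx context.Context, modelPath string) (ml.Backend, error) {
        // New decodes metadata and reserves weight/cache buffers; per-device
        // sizes and allocation statuses are recorded before any bytes are read.
        b, err := ml.NewBackend(modelPath, ml.BackendParams{NumGPULayers: 99})
        if err != nil {
            return nil, err
        }
        slog.Info("planned allocations", "memory", b.BackendMemory())

        // Load streams tensor data from modelPath; progress reports the fraction of bytes copied.
        if err := b.Load(ctx, func(p float32) { slog.Debug("load", "progress", p) }); err != nil {
            return nil, err
        }
        return b, nil
    }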
func (b *Backend) Config() fs.Config {
return b.meta.KV()
}

@@ -457,6 +521,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
no_alloc: true,
}),
allocatedBuffers: &allocatedBuffers,
+ layer: -1,
}
}

@@ -483,6 +548,9 @@ type Context struct {

// maxGraphNodes is the maximum allowed number of graph nodes in this context
maxGraphNodes int

+ // layer is the graph layer that this context is allocating for - assumed to be cache
+ layer int
}

func (c *Context) Input() ml.Context {

@@ -493,6 +561,7 @@ func (c *Context) Input() ml.Context {
buft: c.b.input,
allocatedBuffers: c.allocatedBuffers,
maxGraphNodes: c.maxGraphNodes,
+ layer: -1,
}
}

@@ -507,6 +576,7 @@ func (c *Context) Layer(i int) ml.Context {
buft: buft,
allocatedBuffers: c.allocatedBuffers,
maxGraphNodes: c.maxGraphNodes,
+ layer: i,
}
}

@@ -544,22 +614,34 @@ func (c *Context) Compute(tensors ...ml.Tensor) {
}
}

- func (c *Context) Reserve() error {
- if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
- C.ggml_backend_sched_reset(c.b.sched)
- return errors.New("failed to reserve graph")
- }
+ func (c *Context) Reserve() {
+ reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph)

slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
- for i := range c.b.schedBackends {
- size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i])
- slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
- "size", format.HumanBytes2(uint64(size)))
- }
+ // Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations
+ for _, bt := range c.b.schedBufts {
+ c.b.btDeviceMemory[bt].Graph = ml.Memory{}
}

- C.ggml_backend_sched_reset(c.b.sched)
- return nil
+ for i := range c.b.schedBackends {
+ bufferStatus := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i])

+ graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph
+ graph.Size += uint64(bufferStatus.size)
+ if bufferStatus.allocated && graph.Status != ml.Failed {
+ graph.Status = ml.Allocated
+ } else {
+ graph.Status = ml.Failed
+ }

+ slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
+ "size", format.HumanBytes2(uint64(bufferStatus.size)))
+ }

+ if !reserved {
+ panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
+ }
}

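Reserve no longer returns an error: per-buffer-type graph sizes are accumulated into btDeviceMemory, and an allocation failure is reported by panicking with ml.ErrNoMem, which carries the full memory snapshot. A hedged sketch of how a caller might convert that panic back into data; the actual recovery site is outside this diff, and it assumes ml.Context's Reserve mirrors the concrete signature above:

    // tryReserve returns the failed memory layout, or nil if reservation succeeded.
    func tryReserve(c ml.Context) (failed *ml.BackendMemory) {
        defer func() {
            if r := recover(); r != nil {
                noMem, ok := r.(ml.ErrNoMem)
                if !ok {
                    panic(r) // not an allocation failure; re-raise
                }
                // Size and Status are recorded per device for weights, cache and graph.
                failed = &noMem.BackendMemory
            }
        }()
        c.Reserve()
        return nil
    }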
func (c *Context) MaxGraphNodes() int {

@@ -579,7 +661,7 @@ func pad(length, pad C.size_t) C.size_t {
return ((length + pad - 1) / pad) * pad
}

- func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
+ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
if c.buft == nil {
panic("set Input or Layer before creating tensors")
}

@@ -602,7 +684,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {

if len(shape) < 1 || shape[0] == 0 {
var shape C.int64_t = 0
- return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
+ return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
} else if len(shape) > 4 {
panic("unsupported number of dimensions")
}

@@ -615,40 +697,43 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {

t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
- b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
- if b == nil {
- return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
- }
- *c.allocatedBuffers = append(*c.allocatedBuffers, b)

+ b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
+ if c.layer >= 0 {
+ cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer]

+ cache.Size += uint64(size)
+ if b != nil {
+ cache.Status = ml.Allocated
+ } else {
+ cache.Status = ml.Failed
+ }
+ }

+ if b == nil {
+ panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
+ }

+ *c.allocatedBuffers = append(*c.allocatedBuffers, b)
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
- return &Tensor{b: c.b, t: t}, nil
+ return &Tensor{b: c.b, t: t}
}

func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
- t, err := c.newTensor(dtype, shape)
- if err != nil {
- panic(err)
- }

- return t
+ return c.newTensor(dtype, shape)
}

func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
- t, err := c.newTensor(dtype, shape)
- if err != nil {
- panic(err)
- }

+ t := c.newTensor(dtype, shape)
C.ggml_set_zero(t.(*Tensor).t)
return t
}

- func checkShape[S ~[]E, E any](s S, shape ...int) error {
+ func checkShape[S ~[]E, E any](s S, shape ...int) {
n := len(s)

if n == 0 {
- return nil
+ return
}

for _, v := range shape {

@@ -656,44 +741,32 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
}

if n != 1 {
- return fmt.Errorf("invalid shape: %v", shape)
+ panic(fmt.Errorf("invalid shape: %v", shape))
+ }
}

- return nil
- }

- func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
- if err := checkShape(s, shape...); err != nil {
- return nil, err
- }

- t, err := c.newTensor(ml.DTypeF32, shape)
- if err != nil {
- return nil, err
- }
+ func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
+ checkShape(s, shape...)

+ t := c.newTensor(ml.DTypeF32, shape)

if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}

- return t, nil
+ return t
}

- func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
- if err := checkShape(s, shape...); err != nil {
- return nil, err
- }

- t, err := c.newTensor(ml.DTypeI32, shape)
- if err != nil {
- return nil, err
- }
+ func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor {
+ checkShape(s, shape...)

+ t := c.newTensor(ml.DTypeI32, shape)

if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}

- return t, nil
+ return t
}

func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {

@@ -711,12 +784,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
arange = append(arange, int32(i))
}

- t, err := c.Input().FromIntSlice(arange, len(arange))
- if err != nil {
- panic(err)
- }

- return t
+ return c.Input().FromIntSlice(arange, len(arange))
default:
panic("unsupported dtype for arange")
}

@@ -867,6 +935,13 @@ func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}

+ func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
+ }
+ }

func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,

@@ -984,6 +1059,13 @@ func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
}
}

+ func (t *Tensor) SumRows(ctx ml.Context) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_sum_rows(ctx.(*Context).ctx, t.t),
+ }
+ }

func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,

@@ -1055,28 +1137,15 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
}

- const (
- ropeTypeNorm C.int = 0
- ropeTypeNeox C.int = 2
- ropeTypeMrope C.int = 8
- ropeTypeVision C.int = 24
- )

- func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
+ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor {
// Default options
- opts := &ml.RopeOptions{
- OriginalContextLen: 131072,
- }
+ opts := &rope.Options{OriginalContextLength: 131072, Factors: &Tensor{}}

// Apply any provided options
for _, option := range options {
option(opts)
}

- if ropeFactors == nil {
- ropeFactors = &Tensor{b: t.b}
- }

dequant := t.t
if C.ggml_is_quantized(t.t._type) {
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)

@@ -1087,11 +1156,11 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
t: C.ggml_rope_ext(
ctx.(*Context).ctx,
dequant,
- positionIDs.(*Tensor).t,
- ropeFactors.(*Tensor).t,
+ positions.(*Tensor).t,
+ opts.Factors.(*Tensor).t,
C.int(ropeDim),
- C.int(ropeType),
- C.int(opts.OriginalContextLen),
+ C.int(opts.Type),
+ C.int(opts.OriginalContextLength),
C.float(ropeBase),
C.float(ropeScale),
C.float(0.0),
@@ -66,6 +66,12 @@ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph

GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);

+ struct ggml_allocr_buffer_status {
+ size_t size;
+ bool allocated;
+ };
+ GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id);

// Utils
// Create a buffer and allocate all the tensors in a ggml_context
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);

@@ -304,6 +304,12 @@ extern "C" {

GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);

+ struct ggml_backend_buffer_status {
+ size_t size;
+ bool allocated;
+ };
+ GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);

GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);

@@ -364,6 +364,7 @@ struct node_alloc {
struct ggml_gallocr {
ggml_backend_buffer_type_t * bufts; // [n_buffers]
ggml_backend_buffer_t * buffers; // [n_buffers]
+ size_t *buffer_sizes; // [n_buffers]
struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
int n_buffers;

@@ -387,6 +388,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs
galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t));
GGML_ASSERT(galloc->buffers != NULL);

+ galloc->buffer_sizes = calloc(n_bufs, sizeof(size_t));
+ GGML_ASSERT(galloc->buffer_sizes != NULL);

galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *));
GGML_ASSERT(galloc->buf_tallocs != NULL);

@@ -453,6 +457,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) {
ggml_hash_set_free(&galloc->hash_set);
free(galloc->hash_values);
free(galloc->bufts);
+ free(galloc->buffer_sizes);
free(galloc->buffers);
free(galloc->buf_tallocs);
free(galloc->node_allocs);

@@ -748,6 +753,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
}
}

+ bool success = true;

// reallocate buffers if needed
for (int i = 0; i < galloc->n_buffers; i++) {
// if the buffer type is used multiple times, we reuse the same buffer

@@ -769,15 +776,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c

ggml_backend_buffer_free(galloc->buffers[i]);
galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
- if (galloc->buffers[i] == NULL) {
- GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
- return false;
- }
+ if (galloc->buffers[i]) {
+ galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]);
ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
+ } else {
+ GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
+ galloc->buffer_sizes[i] = new_size;
+ success = false;
+ }
+ } else {
+ galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]);
}
}

- return true;
+ return success;
}

bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {

@@ -934,6 +946,24 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]);
}

+ struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
+ GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers);

+ for (int i = 0; i < buffer_id; i++) {
+ if (galloc->buf_tallocs[i] == galloc->buf_tallocs[buffer_id]) {
+ // This buffer is the same as a previous one due to the same buffer type being used multiple times
+ // (See above.) However, we need a different check because multiple buffers might be NULL in our
+ // case and we still want to know the attempted size.

+ struct ggml_allocr_buffer_status status = {0, true};
+ return status;
+ }
+ }

+ struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL};
+ return status;
+ }

// utils

static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {

@@ -1629,6 +1629,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
}

+ struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
+ int backend_index = ggml_backend_sched_backend_id(sched, backend);
+ GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);

+ struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index);
+ struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated};

+ return status;
+ }

void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);

@@ -3,7 +3,7 @@ package cpu
// #cgo CFLAGS: -O3 -Wno-implicit-function-declaration
// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/amx -I${SRCDIR}/llamafile -I${SRCDIR}/.. -I${SRCDIR}/../../include
- // #cgo CPPFLAGS: -DGGML_USE_LLAMAFILE
+ // #cgo CPPFLAGS: -DNDEBUG -DGGML_USE_LLAMAFILE
// #cgo linux CPPFLAGS: -D_GNU_SOURCE
// #cgo darwin,arm64 CPPFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,arm64 LDFLAGS: -framework Accelerate

@@ -4,6 +4,6 @@ package metal

//go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal"

- // #cgo CPPFLAGS: -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include
+ // #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include
// #cgo LDFLAGS: -framework Metal -framework MetalKit
import "C"

@@ -0,0 +1,21 @@
+ // fast provides implementations of fast (fused) operations for increased performance.
+ package fast

+ import (
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn/rope"
+ )

+ // fastRoPE is an interface for tensors that support fast rotary positional embedding.
+ type fastRoPE interface {
+ RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor
+ }

+ // RoPE applies rotary positional embedding to tensor `t`.
+ func RoPE(ctx ml.Context, t, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor {
+ if t, ok := t.(fastRoPE); ok {
+ return t.RoPE(ctx, positions, dim, base, scale, options...)
+ }

+ panic("RoPE not implemented for this tensor type")
+ }

@@ -0,0 +1,33 @@
+ package rope

+ import "github.com/ollama/ollama/ml"

+ // Options contains optional parameters for RoPE function
+ type Options struct {
+ OriginalContextLength int
+ Type int
+ Factors ml.Tensor
+ }

+ // WithOriginalContextLength sets a custom context length
+ func WithOriginalContextLength(n int) func(*Options) {
+ return func(opts *Options) {
+ opts.OriginalContextLength = n
+ }
+ }

+ // WithType sets RoPE type to NeoX
+ func WithTypeNeoX() func(*Options) {
+ return func(opts *Options) {
+ opts.Type = 2
+ }
+ }

+ // WithFactors sets custom rope factors
+ func WithFactors(factors ml.Tensor) func(*Options) {
+ return func(opts *Options) {
+ if factors != nil {
+ opts.Factors = factors
+ }
+ }
+ }
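The rope package replaces the old positional ropeType/ropeFactors arguments with functional options consumed through fast.RoPE. A short illustrative sketch of how the options compose inside a model's attention code; the tensor names and option values here are examples, not taken from a specific model:

    // q has shape [headDim, numHeads, batchSize]; positions holds the token positions.
    q = fast.RoPE(ctx, q, positions, headDim, ropeBase, ropeScale,
        rope.WithTypeNeoX(),                    // NeoX-style rotation (Type = 2)
        rope.WithFactors(sa.RopeFactors),       // optional frequency factors, ignored if nil
        rope.WithOriginalContextLength(131072), // override only if the default is wrong
    )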
@@ -98,14 +98,8 @@ func Register(name string, f func(fs.Config) (Model, error)) {
}

// New initializes a new model instance with the provided configuration based on the metadata in the model file
- func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
- r, err := os.Open(modelPath)
- if err != nil {
- return nil, err
- }
- defer r.Close()

- b, err := ml.NewBackend(ctx, r, params)
+ func New(modelPath string, params ml.BackendParams) (Model, error) {
+ b, err := ml.NewBackend(modelPath, params)
if err != nil {
return nil, err
}

@@ -134,7 +128,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
defer r.Close()
- meta, _, err := fsggml.Decode(r, -1)
+ meta, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}

@@ -293,11 +287,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
return nil, errors.New("batch size cannot be less than 1")
}

- var err error
- batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
- if err != nil {
- return nil, err
- }
+ batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))

cache := m.Config().Cache
if cache != nil {

@@ -7,6 +7,8 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/fast"
+ "github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)

@@ -83,11 +85,10 @@ type SelfAttention struct {

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
- ropeType := uint32(2)

q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
- q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+ q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())

if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))

@@ -97,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten

k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
- k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+ k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())

v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

@@ -127,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+ return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
}

type MLP struct {

@@ -174,15 +175,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
- positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- if err != nil {
- return nil, err
- }

- outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
- if err != nil {
- return nil, err
- }
+ positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+ outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))

hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))

@@ -101,14 +101,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err
}

- pixelValues, err := ctx.Input().FromFloatSlice(f32s,
+ pixelValues := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
- if err != nil {
- return nil, err
- }

visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)

@@ -144,15 +141,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
- positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- if err != nil {
- return nil, err
- }

- outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
- if err != nil {
- return nil, err
- }
+ positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+ outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))

return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}

@@ -7,6 +7,8 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/fast"
+ "github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)

@@ -73,7 +75,6 @@ type TextSelfAttention struct {

func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1)
- ropeType := uint32(2)

ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {

@@ -83,7 +84,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
- q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+ q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())

if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))

@@ -94,7 +95,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
- k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+ k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())

v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

@@ -112,7 +113,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.TextConfig.ropeGlobalBase
}

- return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
+ return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}

type TextMLP struct {

@@ -1,22 +1,23 @@
package llama

import (
- "fmt"
+ "cmp"
"math"
- "strings"

"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/fast"
+ "github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)

type Options struct {
hiddenSize, numHeads, numKVHeads int
+ headDim, ropeDim int
eps, ropeBase, ropeScale float32
- ropeDim uint32
}

type Model struct {

@@ -32,10 +33,6 @@ type Model struct {
}

func New(c fs.Config) (model.Model, error) {
- if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
- return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
- }

m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),

@@ -57,10 +54,11 @@ func New(c fs.Config) (model.Model, error) {
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
+ headDim: int(c.Uint("attention.key_length")),
+ ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
- ropeDim: c.Uint("rope.dimension_count"),
},
}

@@ -77,31 +75,31 @@ type SelfAttention struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}

- func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
- headDim := opts.hiddenSize / opts.numHeads
- ropeType := uint32(0)
+ headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
+ ropeDim := cmp.Or(opts.ropeDim, headDim)

- q := sa.Query.Forward(ctx, hiddenState)
- q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
- q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+ query := sa.Query.Forward(ctx, hiddenState)
+ query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)

- k := sa.Key.Forward(ctx, hiddenState)
- k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+ key := sa.Key.Forward(ctx, hiddenState)
+ key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

- v := sa.Value.Forward(ctx, hiddenState)
- v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
+ value := sa.Value.Forward(ctx, hiddenState)
+ value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

- scaleFactor := 1.0 / math.Sqrt(float64(headDim))
- kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
- kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

- return sa.Output.Forward(ctx, kqv)
+ query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+ key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))

+ attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
+ attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
+ return sa.Output.Forward(ctx, attention)
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+ ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
+ return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
}

type MLP struct {
|
@ -122,11 +120,11 @@ type Layer struct {
|
||||||
MLP *MLP
|
MLP *MLP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
residual := hiddenState
|
residual := hiddenState
|
||||||
|
|
||||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts)
|
||||||
|
|
||||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||||
// we need logits for.
|
// we need logits for.
|
||||||
|
@ -144,27 +142,19 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
m.Cache.SetLayer(i)
|
m.Cache.SetLayer(i)
|
||||||
|
|
||||||
var lastLayerOutputs ml.Tensor
|
var outputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
lastLayerOutputs = outputs
|
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||||
|
|
|
@@ -77,10 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		return nil, err
 	}
 
-	tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
-	if err != nil {
-		return nil, err
-	}
+	tilesLocal := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
 
 	ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize
 
@@ -91,11 +88,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	pixelValues := tilesLocal
 
 	if len(pixelsGlobal) > 0 {
-		tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
-		if err != nil {
-			return nil, err
-		}
-
+		tilesGlobal := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
 		pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3)
 	}
 
@@ -182,15 +175,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
 }
@@ -8,6 +8,8 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model/input"
 )
 
@@ -31,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
 	if useRope {
-		query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
-		key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
+		query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+		key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 	}
 
 	if opts.useQKNorm {
@@ -80,7 +82,7 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
 
 	nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
 	for i := 1; i < opts.numExpertsUsed; i++ {
-		nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
+		nextStates = nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
 	}
 
 	return nextStates
@@ -221,11 +223,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 			scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0)
 		}
 
-		var err error
-		attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
-		if err != nil {
-			panic(err)
-		}
+		attentionScales = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
 	}
 
 	for i, layer := range m.Layers {
@@ -250,5 +248,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
 }
@@ -245,10 +245,7 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
 		}
 	}
 
-	ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
-	if err != nil {
-		panic(err)
-	}
+	ropeFreqs := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
 
 	ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)
@@ -31,11 +31,6 @@ var _ model.MultimodalProcessor = (*Model)(nil)
 var _ model.TextProcessor = (*Model)(nil)
 
 func New(c fs.Config) (model.Model, error) {
-	textModel, err := NewTextModel(c)
-	if err != nil {
-		return nil, err
-	}
-
 	m := &Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -52,7 +47,7 @@ func New(c fs.Config) (model.Model, error) {
 				),
 			},
 		),
-		TextModel: textModel,
+		TextModel: newTextModel(c),
 		VisionModel: newVisionModel(c),
 		ImageProcessor: newImageProcessor(c),
 		MultiModalProjector: newMultiModalProjector(c),
@@ -119,10 +114,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		return nil, err
 	}
 
-	pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
-	if err != nil {
-		return nil, err
-	}
+	pixelValues := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
 
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 	features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
@@ -166,15 +158,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
 }
@@ -1,21 +1,21 @@
 package mistral3
 
 import (
-	"fmt"
+	"cmp"
 	"math"
-	"strings"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/model/input"
 )
 
 type TextOptions struct {
-	hiddenSize, numHeads, numKVHeads, headDim int
+	hiddenSize, numHeads, numKVHeads int
+	headDim, ropeDim int
 	eps, ropeBase, ropeScale float32
-	ropeDim uint32
 }
 
 type TextModel struct {
@@ -36,19 +36,15 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(0)
-	headDim := opts.headDim
-	if headDim == 0 {
-		headDim = opts.hiddenSize / opts.numHeads
-	}
+	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -59,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 
 type MLP struct {
@@ -125,24 +121,18 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	return m.Output.Forward(ctx, hiddenState)
 }
 
-func NewTextModel(c fs.Config) (*TextModel, error) {
-	if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
-		return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
-	}
-
-	textModel := &TextModel{
+func newTextModel(c fs.Config) *TextModel {
+	return &TextModel{
 		Layers: make([]Layer, c.Uint("block_count")),
 		TextOptions: &TextOptions{
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads: int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			headDim: int(c.Uint("attention.key_length")),
+			ropeDim: int(c.Uint("rope.dimension_count")),
 			eps: c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase: c.Float("rope.freq_base"),
 			ropeScale: c.Float("rope.freq_scale", 1),
-			ropeDim: c.Uint("rope.dimension_count"),
 		},
 	}
-
-	return textModel, nil
 }
@@ -110,15 +110,8 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor)
 		}
 	}
 
-	h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
-	if err != nil {
-		panic(err)
-	}
-
-	w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
-	if err != nil {
-		panic(err)
-	}
+	h := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
+	w := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
 
 	h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 	w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
@@ -151,10 +144,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 		}
 	}
 
-	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
-	if err != nil {
-		panic(err)
-	}
+	positionIDs := ctx.Input().FromIntSlice(positions, len(positions))
 
 	positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
 	cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
@@ -170,7 +160,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 
 func newVisionModel(c fs.Config) *VisionModel {
 	return &VisionModel{
-		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
+		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
 		VisionModelOptions: &VisionModelOptions{
 			hiddenSize: int(c.Uint("vision.embedding_length", 1024)),
 			numHeads: int(c.Uint("vision.attention.head_count", 16)),
@@ -80,15 +80,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles]
 	}
 
-	pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
-	if err != nil {
-		return nil, err
-	}
-
-	aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
-	if err != nil {
-		return nil, err
-	}
+	pixelValues := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
+	aspectRatio := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
 
 	positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
 	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
@@ -113,15 +106,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor
 	}
 
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	// TODO: attention mask, cross attention mask
 	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
@@ -8,6 +8,8 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
 )
 
 type TextSelfAttention struct {
@@ -21,15 +23,14 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	ropeType := uint32(0)
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -44,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+		return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
 	}
 
 	return key, nil
@@ -199,8 +200,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 
 type TextModelOptions struct {
 	hiddenSize, numHeads, numKVHeads int
+	ropeDim int
 	eps, ropeBase, ropeScale float32
-	ropeDim uint32
 
 	crossAttentionLayers []int32
 }
@@ -240,10 +241,10 @@ func newTextModel(c fs.Config) *TextModel {
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads: int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			ropeDim: int(c.Uint("rope.dimension_count")),
 			eps: c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase: c.Float("rope.freq_base"),
 			ropeScale: c.Float("rope.freq_scale", 1),
-			ropeDim: c.Uint("rope.dimension_count"),
 			crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
 		},
 	}
@@ -16,8 +16,6 @@ type VisionSelfAttention struct {
 	Key *nn.Linear `gguf:"attn_k"`
 	Value *nn.Linear `gguf:"attn_v"`
 	Output *nn.Linear `gguf:"attn_output"`
-
-	Gate ml.Tensor `gguf:"attn_gate"`
 }
 
 func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
@@ -25,27 +23,16 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
-	hiddenState = sa.Output.Forward(ctx, attention)
-	return hiddenState
+	return sa.Output.Forward(ctx, attention)
 }
 
 type VisionMLP struct {
@@ -76,21 +63,18 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts
 	// self attention
 	hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
 
 	if e.AttentionGate != nil {
 		hiddenState = hiddenState.Mul(ctx, e.AttentionGate)
 	}
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
-	// feed forward
 	hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
-	hiddenState = hiddenState.Add(ctx, residual)
 	if e.MLPGate != nil {
 		hiddenState = hiddenState.Mul(ctx, e.MLPGate)
 	}
+	hiddenState = hiddenState.Add(ctx, residual)
 	return hiddenState
 }
@@ -7,5 +7,7 @@ import (
 	_ "github.com/ollama/ollama/model/models/llama4"
 	_ "github.com/ollama/ollama/model/models/mistral3"
 	_ "github.com/ollama/ollama/model/models/mllama"
+	_ "github.com/ollama/ollama/model/models/qwen2"
 	_ "github.com/ollama/ollama/model/models/qwen25vl"
+	_ "github.com/ollama/ollama/model/models/qwen3"
 )
|
||||||
|
package qwen2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
"github.com/ollama/ollama/kvcache"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/fast"
|
||||||
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
hiddenSize, numHeads, numKVHeads int
|
||||||
|
headDim, ropeDim int
|
||||||
|
eps, ropeBase, ropeScale float32
|
||||||
|
}
|
||||||
|
|
||||||
|
type Attention struct {
|
||||||
|
Query *nn.Linear `gguf:"attn_q"`
|
||||||
|
Key *nn.Linear `gguf:"attn_k"`
|
||||||
|
Value *nn.Linear `gguf:"attn_v"`
|
||||||
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
|
batchSize := hiddenStates.Dim(1)
|
||||||
|
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
|
||||||
|
ropeDim := cmp.Or(opts.ropeDim, headDim)
|
||||||
|
|
||||||
|
query := attn.Query.Forward(ctx, hiddenStates)
|
||||||
|
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||||
|
|
||||||
|
key := attn.Key.Forward(ctx, hiddenStates)
|
||||||
|
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
|
value := attn.Value.Forward(ctx, hiddenStates)
|
||||||
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
|
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||||
|
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||||
|
|
||||||
|
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||||
|
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||||
|
|
||||||
|
return attn.Output.Forward(ctx, attention)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MLP struct {
|
||||||
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||||
|
Up *nn.Linear `gguf:"ffn_up"`
|
||||||
|
Down *nn.Linear `gguf:"ffn_down"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||||
|
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||||
|
return mlp.Down.Forward(ctx, hiddenStates)
|
||||||
|
}
|
||||||
|
|
||||||
|
type DecoderLayer struct {
|
||||||
|
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||||
|
Attention *Attention
|
||||||
|
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||||
|
MLP *MLP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
|
residual := hiddenStates
|
||||||
|
|
||||||
|
hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
|
hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||||
|
if outputs != nil {
|
||||||
|
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||||
|
residual = residual.Rows(ctx, outputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||||
|
residual = hiddenStates
|
||||||
|
|
||||||
|
hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
|
hiddenStates = d.MLP.Forward(ctx, hiddenStates)
|
||||||
|
return hiddenStates.Add(ctx, residual)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
model.Base
|
||||||
|
model.BytePairEncoding
|
||||||
|
|
||||||
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||||
|
Layers []DecoderLayer `gguf:"blk"`
|
||||||
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||||
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||||
|
|
||||||
|
Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward implements model.Model.
|
||||||
|
func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
|
|
||||||
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
m.Cache.SetLayer(i)
|
||||||
|
|
||||||
|
var outputs ml.Tensor
|
||||||
|
if i == len(m.Layers)-1 {
|
||||||
|
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
|
hiddenStates = m.Output.Forward(ctx, hiddenStates)
|
||||||
|
return hiddenStates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
|
||||||
|
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(c fs.Config) (model.Model, error) {
|
||||||
|
m := Model{
|
||||||
|
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
||||||
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
|
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||||
|
&model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
|
EOS: append(
|
||||||
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Options: Options{
|
||||||
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
|
headDim: int(c.Uint("attention.key_length")),
|
||||||
|
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||||
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
|
ropeScale: c.Float("rope.freq_scale", 1),
|
||||||
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
model.Register("qwen2", New)
|
||||||
|
}
|
|
@@ -69,10 +69,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
 		m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
 	numPatches := grid.Temporal * grid.Height * grid.Width
 
-	pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
-	if err != nil {
-		return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
-	}
+	pixelValues := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
 
 	return pixelValues, grid, nil
 }
@@ -121,13 +118,14 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 			patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
 
 			// First add the vision start token
-			result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 1})
+			result = append(result, input.Input{Token: visionStartToken})
 
 			// Add the image token with the multimodal tensor data at the first position
 			result = append(result, input.Input{
 				Token: imageToken,
 				Multimodal: inp.Multimodal,
 				MultimodalHash: inp.MultimodalHash,
+				SameBatch: patchesPerChunk,
 			})
 
 			// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
@@ -141,15 +139,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
 }
@@ -7,13 +7,15 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model/input"
 )
 
 type TextOptions struct {
-	ctxLen, hiddenSize, numHeads, numKVHeads int
+	hiddenSize, numHeads, numKVHeads int
+	ropeDim, originalContextLength int
 	eps, ropeBase, ropeScale float32
-	ropeDim, defaultContextLen uint32
 }
 
 type TextModel struct {
@@ -29,15 +31,14 @@ func NewTextModel(c fs.Config) *TextModel {
 	m := TextModel{
 		Layers: make([]Layer, c.Uint("block_count")),
 		TextOptions: &TextOptions{
-			ctxLen: int(c.Uint("context_length")),
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads: int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			ropeDim: int(c.Uint("rope.dimension_count", 128)),
+			originalContextLength: int(c.Uint("context_length", 128000)),
 			eps: c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase: c.Float("rope.freq_base"),
 			ropeScale: c.Float("rope.freq_scale", 1),
-			ropeDim: c.Uint("rope.dimension_count", 128),
-			defaultContextLen: c.Uint("context_length", 128000),
 		},
 	}
 
@@ -59,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
+	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
+	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -77,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation
@@ -1,7 +1,6 @@
 package qwen25vl
 
 import (
-	"fmt"
 	"math"
 	"slices"
 
@@ -44,10 +43,8 @@ func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int
 		}
 	}
 
-	mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
-	if err != nil {
-		panic(err)
-	}
+	mask := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
+
 	// Reshape to match [seqLength, seqLength, 1] for broadcasting
 	mask = mask.Reshape(ctx, seqLength, seqLength, 1)
 
@@ -303,10 +300,7 @@ func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int)
 		}
 	}
 
-	t, err := ctx.Input().FromIntSlice(index, len(index))
-	if err != nil {
-		panic(err)
-	}
+	t := ctx.Input().FromIntSlice(index, len(index))
 
 	return t, bounds
 }
@@ -326,10 +320,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
 			freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
 		}
 	}
-	freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize)
-	if err != nil {
-		panic(fmt.Errorf("failed to create tensor from frequencies: %w", err))
-	}
+	freqs := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize)
 
 	// Create position coordinates (y,x pairs) for the grid
 	// In PyTorch: Equivalent to generating position ids with torch.arange()
@@ -339,10 +330,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
 			coords = append(coords, int32(y), int32(x))
 		}
 	}
-	pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
-	if err != nil {
-		panic(fmt.Errorf("failed to create tensor from positions: %w", err))
-	}
+	pos := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
 
 	// Reshape and permute positions to match spatial merging pattern
 	pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
@@ -0,0 +1,233 @@
+package qwen3
+
+import (
+	"cmp"
+	"math"
+
+	"github.com/ollama/ollama/fs"
+	"github.com/ollama/ollama/kvcache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+type Options struct {
+	hiddenSize, numHeads, numKVHeads int
+	eps float32
+	ropeBase, ropeScale float32
+
+	keyLength, valueLength int
+
+	numExperts, numExpertsUsed int
+	normTopKProb bool
+}
+
+func (o Options) headDim() int {
+	return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
+}
+
+type Attention struct {
+	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
+	Query *nn.Linear `gguf:"attn_q"`
+	KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
+	Key *nn.Linear `gguf:"attn_k"`
+	Value *nn.Linear `gguf:"attn_v"`
+	Output *nn.Linear `gguf:"attn_output"`
+}
+
+func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+	batchSize := hiddenStates.Dim(1)
+
+	query := sa.Query.Forward(ctx, hiddenStates)
+	key := sa.Key.Forward(ctx, hiddenStates)
+	value := sa.Value.Forward(ctx, hiddenStates)
+
+	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
+	key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
+	value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
+
+	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
+	key = sa.KeyNorm.Forward(ctx, key, opts.eps)
+
+	query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+
+	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
+	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
+	return sa.Output.Forward(ctx, attention)
+}
+
+type MLP interface {
+	Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
+}
+
+type sparse struct {
+	Router *nn.Linear `gguf:"ffn_gate_inp"`
+	Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
+	Up ml.Tensor `gguf:"ffn_up_exps.weight"`
+	Down ml.Tensor `gguf:"ffn_down_exps.weight"`
+}
+
+func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+	hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
+	hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
+	routerLogits := mlp.Router.Forward(ctx, hiddenStates)
+
+	routingWeights := routerLogits.Softmax(ctx)
+	selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
+	routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts)
+	if opts.normTopKProb {
+		routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
+		routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
+		routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
+	}
+
+	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
+
+	upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts)
+
+	hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts)
+	hiddenStates = hiddenStates.SILU(ctx)
+	hiddenStates = hiddenStates.Mul(ctx, upStates)
+
+	experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts)
+	experts = experts.Mul(ctx, routingWeights)
+
+	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
+	for i := 1; i < opts.numExpertsUsed; i++ {
+		nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
+	}
+
+	return nextStates
+}
+
+type dense struct {
+	Gate *nn.Linear `gguf:"ffn_gate"`
+	Up *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
+}
+
+func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	return mlp.Down.Forward(ctx, hiddenStates)
+}
+
+type Layer struct {
+	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
+	*Attention
+
+	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
+	MLP
+}
+
+func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+	residual := hiddenStates
+	hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
+	hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
+
+	if outputs != nil {
+		hiddenStates = hiddenStates.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
+	hiddenStates = hiddenStates.Add(ctx, residual)
+
+	residual = hiddenStates
+	hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
+	hiddenStates = d.MLP.Forward(ctx, hiddenStates, opts)
+	return hiddenStates.Add(ctx, residual)
+}
+
+type Model struct {
+	model.Base
+	model.BytePairEncoding
+
+	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+	OutputNorm *nn.RMSNorm `gguf:"output_norm"`
+	Output *nn.Linear `gguf:"output,alt:token_embd"`
+
+	Layers []Layer `gguf:"blk"`
+
+	*Options
+}
+
+// Forward implements model.Model.
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+
+	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
+
+	for i, layer := range m.Layers {
+		m.Cache.SetLayer(i)
+
+		var outputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+		}
+
+		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
+	}
+
+	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
+	return m.Output.Forward(ctx, hiddenStates), nil
+}
+
+func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+	return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
+}
+
+var _ model.Model = (*Model)(nil)
+
+func New(c fs.Config) (model.Model, error) {
+	layers := make([]Layer, c.Uint("block_count"))
+	for i := range layers {
+		if c.String("general.architecture") == "qwen3moe" {
+			layers[i].MLP = &sparse{}
+		} else {
+			layers[i].MLP = &dense{}
+		}
+	}
+
+	m := Model{
+		BytePairEncoding: model.NewBytePairEncoding(
+			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
+			&model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Types: c.Ints("tokenizer.ggml.token_type"),
+				Merges: c.Strings("tokenizer.ggml.merges"),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+				BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+				EOS: append(
+					[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
+					c.Ints("tokenizer.ggml.eos_token_ids")...,
+				),
+			},
+		),
+		Layers: layers,
+		Options: &Options{
+			hiddenSize: int(c.Uint("embedding_length")),
+			numHeads: int(c.Uint("attention.head_count")),
+			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			keyLength: int(c.Uint("attention.key_length")),
+			valueLength: int(c.Uint("attention.value_length")),
+			eps: c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase: c.Float("rope.freq_base"),
+			ropeScale: c.Float("rope.freq_scale", 1),
+			numExperts: int(c.Uint("expert_count")),
+			numExpertsUsed: int(c.Uint("expert_used_count")),
|
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
model.Register("qwen3", New)
|
||||||
|
model.Register("qwen3moe", New)
|
||||||
|
}
|
|
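The sparse (MoE) forward pass above routes each token to the top-k experts and, when normTopKProb is set, renormalizes the selected routing weights so they sum to 1 before the expert outputs are mixed. A minimal, self-contained sketch of that routing step on plain slices follows; it does not use the ml tensor API, and the numbers in main are purely illustrative:

package main

import (
	"fmt"
	"math"
	"sort"
)

// routeTopK mirrors the routing step of the sparse MLP: softmax over the
// router logits, pick the top-k experts, and optionally renormalize the
// selected weights so they sum to 1 (the normTopKProb option).
func routeTopK(logits []float64, k int, normalize bool) (experts []int, weights []float64) {
	// softmax over the logits
	max := logits[0]
	for _, v := range logits {
		if v > max {
			max = v
		}
	}
	probs := make([]float64, len(logits))
	var sum float64
	for i, v := range logits {
		probs[i] = math.Exp(v - max)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}

	// indices of the k most probable experts
	idx := make([]int, len(probs))
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return probs[idx[a]] > probs[idx[b]] })
	experts = idx[:k]

	weights = make([]float64, k)
	var selected float64
	for i, e := range experts {
		weights[i] = probs[e]
		selected += probs[e]
	}
	if normalize {
		for i := range weights {
			weights[i] /= selected
		}
	}
	return experts, weights
}

func main() {
	experts, weights := routeTopK([]float64{0.1, 2.0, -1.0, 1.5}, 2, true)
	fmt.Println(experts, weights) // e.g. [1 3] with the two weights summing to 1
}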
@@ -95,17 +95,14 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
 			}
 		}
 	} else {
-		err := computeCtx.Reserve()
-		if err != nil {
-			return nil, err
-		}
+		computeCtx.Reserve()
 	}
 }

 for i, t := range entry.mm {
 	if in == t.Tensor {
 		if !reserve {
-			return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil
 		} else {
 			return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
 		}
@@ -808,10 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 		batch.Outputs[i] = int32(i)
 	}

-	batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
-	if err != nil {
-		return err
-	}
+	batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))

 	cache := s.model.Config().Cache
 	if cache != nil {
@@ -826,16 +823,12 @@ func (s *Server) reserveWorstCaseGraph() error {
 		return err
 	}

-	err = ctx.Forward(t).Reserve()
-	if err != nil {
-		return err
-	}
+	ctx.Forward(t).Reserve()

 	return nil
 }

-func (s *Server) loadModel(
-	ctx context.Context,
+func (s *Server) initModel(
 	mpath string,
 	params ml.BackendParams,
 	lpath multiLPath,
@@ -843,21 +836,21 @@ func (s *Server) loadModel(
 	kvCacheType string,
 	kvSize int,
 	multiUserCache bool,
-) {
+) error {
 	var err error
-	s.model, err = model.New(ctx, mpath, params)
+	s.model, err = model.New(mpath, params)
 	if err != nil {
-		panic(err)
+		return err
 	}

 	// TODO(jessegross): LoRA loading
 	if lpath.String() != "" {
-		panic("loras are not yet implemented")
+		return errors.New("loras are not yet implemented")
 	}

 	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
 	if err != nil {
-		panic(err)
+		return err
 	}

 	if !s.cache.enabled && parallel > 1 {
@@ -869,7 +862,30 @@ func (s *Server) loadModel(
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

-	err = s.reserveWorstCaseGraph()
+	return s.reserveWorstCaseGraph()
+}
+
+func (s *Server) load(
+	ctx context.Context,
+	mpath string,
+	params ml.BackendParams,
+	lpath multiLPath,
+	parallel int,
+	kvCacheType string,
+	kvSize int,
+	multiUserCache bool,
+) {
+	err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
+	if err != nil {
+		panic(err)
+	}
+
+	slog.Debug("memory", "allocated", s.model.Backend().BackendMemory())
+
+	err = s.model.Backend().Load(ctx,
+		func(progress float32) {
+			s.progress = progress
+		})
 	if err != nil {
 		panic(err)
 	}
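The split into initModel and load follows a common pattern: synchronous setup that can return errors, followed by a background goroutine that loads weights and reports readiness and progress. A minimal, self-contained sketch of that pattern is below; the weightLoader type and its fields are illustrative stand-ins, not names from the ollama codebase:

package main

import (
	"fmt"
	"sync"
	"time"
)

// weightLoader is a stand-in for the runner's Server: initialization runs
// synchronously and can fail, while the slow weight load runs in a goroutine
// and signals readiness through a WaitGroup (like s.ready in the runner).
type weightLoader struct {
	mu       sync.Mutex
	progress float32
	ready    sync.WaitGroup
}

func (w *weightLoader) init() error {
	// validate configuration, allocate caches, etc.
	return nil
}

func (w *weightLoader) load() {
	defer w.ready.Done()
	for i := 1; i <= 4; i++ {
		time.Sleep(10 * time.Millisecond) // pretend to read one shard
		w.mu.Lock()
		w.progress = float32(i) / 4
		w.mu.Unlock()
	}
}

func main() {
	w := &weightLoader{}
	if err := w.init(); err != nil {
		panic(err)
	}
	w.ready.Add(1)
	go w.load()
	w.ready.Wait() // request handlers would block here until the model is loaded
	fmt.Printf("loaded, progress=%.2f\n", w.progress)
}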
@@ -913,9 +929,14 @@ func Execute(args []string) error {
 		status: llm.ServerStatusLoadingModel,
 	}

+	server.cond = sync.NewCond(&server.mu)
+	server.ready.Add(1)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	// TODO(jessegross): Parameters that need to be implemented:
 	// no-mmap
-	// mlock

 	var tensorSplitFloats []float32
 	if *tensorSplit != "" {
@@ -928,9 +949,6 @@ func Execute(args []string) error {
 	}

 	params := ml.BackendParams{
-		Progress: func(progress float32) {
-			server.progress = progress
-		},
 		NumThreads:   *threads,
 		NumGPULayers: *numGPULayers,
 		MainGPU:      *mainGPU,
@@ -938,14 +956,7 @@ func Execute(args []string) error {
 		FlashAttention: *flashAttention,
 	}

-	server.ready.Add(1)
-
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
-
-	server.cond = sync.NewCond(&server.mu)
+	go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)

 	go server.run(ctx)

 	addr := "127.0.0.1:" + strconv.Itoa(*port)
@@ -295,7 +295,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
 	}
 	defer bin.Close()

-	f, _, err := ggml.Decode(bin, -1)
+	f, err := ggml.Decode(bin, -1)
 	if err != nil {
 		return nil, err
 	}
@@ -467,7 +467,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
 		return nil, err
 	}

-	f, _, err := ggml.Decode(temp, 1024)
+	f, err := ggml.Decode(temp, 1024)
 	if err != nil {
 		slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
 		return nil, err
@@ -501,47 +501,26 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
 	return nil, errOnlyGGUFSupported
 	}

-	stat, err := blob.Stat()
+	f, err := ggml.Decode(blob, -1)
 	if err != nil {
 		return nil, err
 	}

-	var offset int64
-	for offset < stat.Size() {
-		f, n, err := ggml.Decode(blob, 1024)
-		if errors.Is(err, io.EOF) {
-			break
-		} else if err != nil {
-			return nil, err
-		}
-
 	mediatype := "application/vnd.ollama.image.model"
 	if f.KV().Kind() == "adapter" {
 		mediatype = "application/vnd.ollama.image.adapter"
-	} else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" {
+	} else if (f.KV().Uint("block_count") == 0 && f.KV().Uint("vision.block_count") > 0) || f.KV().Kind() == "projector" {
+		// if a model has vision.block_count but not block_count, it is a standalone vision model
 		mediatype = "application/vnd.ollama.image.projector"
 	}

-	var layer Layer
-	if digest != "" && n == stat.Size() && offset == 0 {
-		layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
+	layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
 	if err != nil {
 		slog.Debug("could not create new layer from layer", "error", err)
 		return nil, err
 	}
-	}
-
-	// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
-	if layer.Digest == "" {
-		layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype)
-		if err != nil {
-			return nil, err
-		}
-	}

 	layers = append(layers, &layerGGML{layer, f})
-	offset = n
-	}

 	return detectChatTemplate(layers)
 }
@@ -464,6 +464,10 @@ type downloadOpts struct {

 // downloadBlob downloads a blob from the registry and stores it in the blobs directory
 func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
+	if opts.digest == "" {
+		return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is is empty")
+	}
+
 	fp, err := GetBlobsPath(opts.digest)
 	if err != nil {
 		return false, err
@@ -75,7 +75,7 @@ func (m *Model) Capabilities() []model.Capability {
 	if err == nil {
 		defer r.Close()

-		f, _, err := ggml.Decode(r, 1024)
+		f, err := ggml.Decode(r, 1024)
 		if err == nil {
 			if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
 				capabilities = append(capabilities, model.CapabilityEmbedding)
server/model.go (126 lines changed)

@@ -10,9 +10,6 @@ import (
 	"log/slog"
 	"net/http"
 	"os"
-	"slices"
-	"strings"
-	"text/template/parse"

 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/fs/ggml"

@@ -64,7 +61,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 	}
 	defer blob.Close()

-	f, _, err := ggml.Decode(blob, -1)
+	f, err := ggml.Decode(blob, -1)
 	if err != nil {
 		return nil, err
 	}

@@ -128,124 +125,3 @@ func detectContentType(r io.Reader) (string, error) {

 	return "unknown", nil
 }
-
-func parseObjects(s string) []map[string]any {
-	var objs []map[string]any
-	for offset := 0; offset < len(s); {
-		var obj map[string]any
-		decoder := json.NewDecoder(strings.NewReader(s[offset:]))
-		if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
-			break
-		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
-			// skip over any syntax errors
-			offset += int(syntax.Offset)
-		} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
-			// skip over any unmarshalable types
-			offset += int(unmarshalType.Offset)
-		} else if err != nil {
-			return nil
-		} else {
-			offset += int(decoder.InputOffset())
-			objs = append(objs, obj)
-		}
-	}
-
-	return objs
-}
-
-// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
-// mxyng: this only really works if the input contains tool calls in some JSON format
-func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
-	// create a subtree from the node that ranges over .ToolCalls
-	tmpl := m.Template.Subtree(func(n parse.Node) bool {
-		if t, ok := n.(*parse.RangeNode); ok {
-			return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
-		}
-
-		return false
-	})
-
-	if tmpl == nil {
-		return nil, false
-	}
-
-	var b bytes.Buffer
-	if err := tmpl.Execute(&b, map[string][]api.ToolCall{
-		"ToolCalls": {
-			{
-				Function: api.ToolCallFunction{
-					Name: "@@name@@",
-					Arguments: api.ToolCallFunctionArguments{
-						"@@argument@@": 1,
-					},
-				},
-			},
-		},
-	}); err != nil {
-		return nil, false
-	}
-
-	templateObjects := parseObjects(b.String())
-	if len(templateObjects) == 0 {
-		return nil, false
-	}
-
-	// find the keys that correspond to the name and arguments fields
-	var name, arguments string
-	for k, v := range templateObjects[0] {
-		switch v.(type) {
-		case string:
-			name = k
-		case map[string]any:
-			arguments = k
-		}
-	}
-
-	if name == "" || arguments == "" {
-		return nil, false
-	}
-
-	responseObjects := parseObjects(s)
-	if len(responseObjects) == 0 {
-		return nil, false
-	}
-
-	// collect all nested objects
-	var collect func(any) []map[string]any
-	collect = func(obj any) (all []map[string]any) {
-		switch o := obj.(type) {
-		case map[string]any:
-			all = append(all, o)
-			for _, v := range o {
-				all = append(all, collect(v)...)
-			}
-		case []any:
-			for _, v := range o {
-				all = append(all, collect(v)...)
-			}
-		}
-
-		return all
-	}
-
-	var objs []map[string]any
-	for _, p := range responseObjects {
-		objs = append(objs, collect(p)...)
-	}
-
-	var toolCalls []api.ToolCall
-	for _, kv := range objs {
-		n, nok := kv[name].(string)
-		a, aok := kv[arguments].(map[string]any)
-		if nok && aok {
-			toolCalls = append(toolCalls, api.ToolCall{
-				Function: api.ToolCallFunction{
-					Name:      n,
-					Arguments: a,
-				},
-			})
-		}
-	}
-
-	return toolCalls, len(toolCalls) > 0
-}
@ -1,179 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/template"
|
|
||||||
)
|
|
||||||
|
|
||||||
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
bts, err := os.ReadFile(filepath.Join(base, name))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return bytes.NewBuffer(bts)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExecuteWithTools(t *testing.T) {
|
|
||||||
p := filepath.Join("testdata", "tools")
|
|
||||||
cases := []struct {
|
|
||||||
model string
|
|
||||||
output string
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
|
||||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
|
||||||
|
|
||||||
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
|
|
||||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
|
|
||||||
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
|
||||||
|
|
||||||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
|
||||||
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
|
||||||
{"command-r-plus", "Action: ```json" + `
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"tool_name": "get_current_weather",
|
|
||||||
"parameters": {
|
|
||||||
"format": "fahrenheit",
|
|
||||||
"location": "San Francisco, CA"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tool_name": "get_current_weather",
|
|
||||||
"parameters": {
|
|
||||||
"format": "celsius",
|
|
||||||
"location": "Toronto, Canada"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
` + "```", true},
|
|
||||||
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
|
||||||
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
|
||||||
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
|
||||||
{"llama3-groq-tool-use", `<tool_call>
|
|
||||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
|
||||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
|
|
||||||
</tool_call>`, true},
|
|
||||||
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
|
|
||||||
{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
|
|
||||||
}
|
|
||||||
|
|
||||||
var tools []api.Tool
|
|
||||||
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var messages []api.Message
|
|
||||||
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
calls := []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_current_weather",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
|
||||||
"format": "fahrenheit",
|
|
||||||
"location": "San Francisco, CA",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_current_weather",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
|
||||||
"format": "celsius",
|
|
||||||
"location": "Toronto, Canada",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range cases {
|
|
||||||
t.Run(tt.model, func(t *testing.T) {
|
|
||||||
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("template", func(t *testing.T) {
|
|
||||||
var actual bytes.Buffer
|
|
||||||
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
|
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("parse", func(t *testing.T) {
|
|
||||||
m := &Model{Template: tmpl}
|
|
||||||
actual, ok := m.parseToolCalls(tt.output)
|
|
||||||
if ok != tt.ok {
|
|
||||||
t.Fatalf("expected %t, got %t", tt.ok, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.ok {
|
|
||||||
if diff := cmp.Diff(actual, calls); diff != "" {
|
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseObjects(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
want []map[string]any
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
|
||||||
want: []map[string]any{
|
|
||||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
|
||||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
|
|
||||||
want: []map[string]any{
|
|
||||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
|
|
||||||
want: []map[string]any{
|
|
||||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
|
||||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: `{"name": "get_current_weather", "arguments": `,
|
|
||||||
want: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.input, func(t *testing.T) {
|
|
||||||
got := parseObjects(tc.input)
|
|
||||||
|
|
||||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@@ -116,7 +116,7 @@ func (mp ModelPath) BaseURL() *url.URL {
 func GetManifestPath() (string, error) {
 	path := filepath.Join(envconfig.Models(), "manifests")
 	if err := os.MkdirAll(path, 0o755); err != nil {
-		return "", err
+		return "", fmt.Errorf("%w: ensure path elements are traversable", err)
 	}

 	return path, nil

@@ -139,7 +139,7 @@ func GetBlobsPath(digest string) (string, error) {
 	}

 	if err := os.MkdirAll(dirPath, 0o755); err != nil {
-		return "", err
+		return "", fmt.Errorf("%w: ensure path elements are traversable", err)
 	}

 	return path, nil
@@ -120,15 +120,31 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
 	if newType.IsQuantized() {
 		nx := shape[0]
-		ny := uint64(1)
-		if len(shape) > 1 {
-			ny = shape[1]
-		}
 		qk_k := newType.BlockSize()

+		// Check if first dimension is divisible by block size
 		if nx%qk_k != 0 {
-			slog.Warn(fmt.Sprintf("tensor cols %d x %d are not divisible by %d, required for %s. Falling back to quantization %s", nx, ny, qk_k, newType.String(), fsggml.TensorTypeF16.String()))
+			// Store the original type for logging
+			originalType := newType
+
+			// Select appropriate fallback based on original type
+			switch newType {
+			case fsggml.TensorTypeQ4_K:
+				newType = fsggml.TensorTypeQ5_0
+			case fsggml.TensorTypeQ5_K:
+				newType = fsggml.TensorTypeQ5_1
+			case fsggml.TensorTypeQ6_K:
+				newType = fsggml.TensorTypeQ8_0
+			}
+
+			// Final check - if still incompatible, fall back to F16
+			if nx%newType.BlockSize() != 0 {
 				newType = fsggml.TensorTypeF16
 			}
+
+			slog.Warn(fmt.Sprintf("tensor cols %d are not divisible by %d, required for %s - using fallback quantization %s",
+				nx, qk_k, originalType.String(), newType.String()))
+		}
 	}
 	return newType
 }
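To make the divisibility rule concrete, here is a small, self-contained sketch of the fallback idea; it is not the ollama implementation, and the block sizes used (256 elements per super-block for the K-quants, 32 for the legacy Q5_0/Q5_1/Q8_0 formats) are assumptions stated for illustration:

package main

import "fmt"

// blockSize returns an assumed ggml block size for a few quant types.
// The values here are illustrative assumptions, not taken from this patch.
func blockSize(t string) uint64 {
	switch t {
	case "Q4_K", "Q5_K", "Q6_K":
		return 256
	case "Q5_0", "Q5_1", "Q8_0":
		return 32
	}
	return 1 // F16 and friends have no packing constraint
}

// fallbackType mirrors the shape of the patch's fallback chain:
// try a legacy quant whose block size still divides the row length,
// otherwise give up and use F16.
func fallbackType(t string, nx uint64) string {
	if nx%blockSize(t) == 0 {
		return t
	}
	next := map[string]string{"Q4_K": "Q5_0", "Q5_K": "Q5_1", "Q6_K": "Q8_0"}[t]
	if next != "" && nx%blockSize(next) == 0 {
		return next
	}
	return "F16"
}

func main() {
	fmt.Println(fallbackType("Q4_K", 4096)) // Q4_K: 4096 is a multiple of 256
	fmt.Println(fallbackType("Q4_K", 1600)) // Q5_0: 1600%256 != 0 but 1600%32 == 0
	fmt.Println(fallbackType("Q4_K", 1000)) // F16: not a multiple of 256 or 32
}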
@@ -271,7 +271,7 @@ func TestQuantizeModel(t *testing.T) {
 		t.Fatal(err.Error())
 	}
 	defer fp.Close()
-	meta, _, err := fsggml.Decode(fp, -1)
+	meta, err := fsggml.Decode(fp, -1)
 	if err != nil {
 		t.Fatal(err.Error())
 	}

@@ -303,7 +303,7 @@ func TestQuantizeModel(t *testing.T) {
 		t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
 	}
 	defer fpNew.Close()
-	newMeta, _, err := fsggml.Decode(fpNew, -1)
+	newMeta, err := fsggml.Decode(fpNew, -1)
 	if err != nil {
 		t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
 	}
@@ -38,6 +38,7 @@ import (
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/registry"
 	"github.com/ollama/ollama/template"
+	"github.com/ollama/ollama/tools"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"

@@ -1482,11 +1483,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}

+	var toolParser *tools.Parser
+	if len(req.Tools) > 0 {
+		toolParser, err = tools.NewParser(m.Template.Template)
+		if err != nil {
+			slog.Error("failed to create tool parser", "error", err)
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+	}
+
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		var sb strings.Builder
-		var toolCallIndex int = 0
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt: prompt,
 			Images: images,

@@ -1512,37 +1522,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}

-			// TODO: tool call checking and filtering should be moved outside of this callback once streaming
-			// however this was a simple change for now without reworking streaming logic of this (and other)
-			// handlers
-			if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
-				ch <- res
-				return
-			}
-
-			// Streaming tool calls:
-			// If tools are recognized, use a flag to track the sending of a tool downstream
-			// This ensures that content is cleared from the message on the last chunk sent
-			sb.WriteString(r.Content)
-			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
-				res.Message.ToolCalls = toolCalls
-				for i := range toolCalls {
-					toolCalls[i].Function.Index = toolCallIndex
-					toolCallIndex++
-				}
-				res.Message.Content = ""
-				sb.Reset()
-				ch <- res
-				return
-			}
-
-			if r.Done {
-				// Send any remaining content if no tool calls were detected
-				if toolCallIndex == 0 {
-					res.Message.Content = sb.String()
-				}
-				ch <- res
-			}
+			if len(req.Tools) > 0 {
+				toolCalls, content := toolParser.Add(r.Content)
+				if len(content) > 0 {
+					res.Message.Content = content
+				} else if len(toolCalls) > 0 {
+					res.Message.ToolCalls = toolCalls
+					res.Message.Content = ""
+				} else {
+					if r.Done {
+						ch <- res
+					}
+					return
+				}
+			}
+
+			ch <- res
 		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}

@@ -1551,11 +1545,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	if req.Stream != nil && !*req.Stream {
 		var resp api.ChatResponse
 		var sb strings.Builder
+		var toolCalls []api.ToolCall
 		for rr := range ch {
 			switch t := rr.(type) {
 			case api.ChatResponse:
 				sb.WriteString(t.Message.Content)
 				resp = t
+				if len(req.Tools) > 0 {
+					toolCalls = append(toolCalls, t.Message.ToolCalls...)
+				}
 			case gin.H:
 				msg, ok := t["error"].(string)
 				if !ok {

@@ -1571,12 +1569,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}

 		resp.Message.Content = sb.String()
-		if len(req.Tools) > 0 {
-			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
-				resp.Message.ToolCalls = toolCalls
-				resp.Message.Content = ""
-			}
-		}
+		if len(toolCalls) > 0 {
+			resp.Message.ToolCalls = toolCalls
+		}

 		c.JSON(http.StatusOK, resp)
@@ -387,6 +387,17 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
 			s.loadedMu.Unlock()
 			runner.refMu.Unlock()
 			slog.Debug("duplicate expired event, ignoring", "runner", runner)
+		} else if runner.pid != runnerToUnload.pid {
+			// If the pids do not match, we likely had multiple load
+			// failures for the same model in quick succession due to
+			// request context canceled and are draining the queue of
+			// events. Ensure the orphaned runner is properly shut down, but
+			// do not delete the mismatched loaded runner, or wait for VRAM
+			// convergence.
+			slog.Debug("orphaned runner shutting down", "orphan", runner, "loaded", runnerToUnload)
+			runner.unload()
+			s.loadedMu.Unlock()
+			runner.refMu.Unlock()
 		} else {
 			slog.Debug("starting background wait for VRAM recovery", "runner", runner)
 			finished := runner.waitForVRAMRecovery()
@ -0,0 +1,44 @@
|
||||||
|
<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
|
||||||
|
{{ if .System }}{{ .System }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question.
|
||||||
|
|
||||||
|
You are a helpful assistant with tool calling capabilities.
|
||||||
|
{{- end }}<|eot_id|>
|
||||||
|
{{- range $i, $_ := .Messages }}
|
||||||
|
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
||||||
|
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
|
||||||
|
{{- if and $.Tools $last }}
|
||||||
|
|
||||||
|
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
||||||
|
|
||||||
|
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
||||||
|
|
||||||
|
{{ range $.Tools }}
|
||||||
|
{{- . }}
|
||||||
|
{{ end }}
|
||||||
|
{{ .Content }}<|eot_id|>
|
||||||
|
{{- else }}
|
||||||
|
|
||||||
|
{{ .Content }}<|eot_id|>
|
||||||
|
{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
{{ end }}
|
||||||
|
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
{{- if .ToolCalls }}
|
||||||
|
{{ range .ToolCalls }}
|
||||||
|
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
|
||||||
|
{{- else }}
|
||||||
|
|
||||||
|
{{ .Content }}
|
||||||
|
{{- end }}{{ if not $last }}<|eot_id|>{{ end }}
|
||||||
|
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
{{ end }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
|
@ -0,0 +1,24 @@
|
||||||
|
<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
|
||||||
|
You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question.
|
||||||
|
|
||||||
|
You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|>
|
||||||
|
|
||||||
|
22<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
||||||
|
|
||||||
|
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
||||||
|
|
||||||
|
{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||||
|
|
||||||
|
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
|
||||||
|
{{- else if .Messages }}
|
||||||
|
{{- if or .System .Tools }}<|im_start|>system
|
||||||
|
{{- if .System }}
|
||||||
|
{{ .System }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Tools }}
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
|
||||||
|
You may call one or more functions to assist with the user query.
|
||||||
|
|
||||||
|
You are provided with function signatures within <tools></tools> XML tags:
|
||||||
|
<tools>
|
||||||
|
{{- range .Tools }}
|
||||||
|
{"type": "function", "function": {{ .Function }}}
|
||||||
|
{{- end }}
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||||
|
<tool_call>
|
||||||
|
{"name": <function-name>, "arguments": <args-json-object>}
|
||||||
|
</tool_call>
|
||||||
|
{{- end }}<|im_end|>
|
||||||
|
{{ end }}
|
||||||
|
{{- range $i, $_ := .Messages }}
|
||||||
|
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||||
|
{{- if eq .Role "user" }}<|im_start|>user
|
||||||
|
{{ .Content }}<|im_end|>
|
||||||
|
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||||
|
{{ if .Content }}{{ .Content }}
|
||||||
|
{{- else if .ToolCalls }}<tool_call>
|
||||||
|
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||||
|
{{ end }}</tool_call>
|
||||||
|
{{- end }}{{ if not $last }}<|im_end|>
|
||||||
|
{{ end }}
|
||||||
|
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||||
|
<tool_response>
|
||||||
|
{{ .Content }}
|
||||||
|
</tool_response><|im_end|>
|
||||||
|
{{ end }}
|
||||||
|
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||||
|
{{ end }}
|
||||||
|
{{- end }}
|
||||||
|
{{- else }}
|
||||||
|
{{- if .System }}<|im_start|>system
|
||||||
|
{{ .System }}<|im_end|>
|
||||||
|
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||||
|
{{ .Prompt }}<|im_end|>
|
||||||
|
{{ end }}<|im_start|>assistant
|
||||||
|
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}
|
|
@ -0,0 +1,31 @@
|
||||||
|
<|im_start|>system
|
||||||
|
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
|
||||||
|
You may call one or more functions to assist with the user query.
|
||||||
|
|
||||||
|
You are provided with function signatures within <tools></tools> XML tags:
|
||||||
|
<tools>
|
||||||
|
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||||
|
<tool_call>
|
||||||
|
{"name": <function-name>, "arguments": <args-json-object>}
|
||||||
|
</tool_call><|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
What's the weather like today in Paris?<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
<tool_call>
|
||||||
|
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||||
|
</tool_call><|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
<tool_response>
|
||||||
|
22
|
||||||
|
</tool_response><|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
What's the weather like today in San Francisco and Toronto?<|im_end|>
|
||||||
|
<|im_start|>assistant
|
|
@ -0,0 +1,50 @@
|
||||||
|
{{- if .Messages }}
|
||||||
|
{{- if or .System .Tools }}<|im_start|>system
|
||||||
|
{{- if .System }}
|
||||||
|
{{ .System }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Tools }}
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
|
||||||
|
You may call one or more functions to assist with the user query.
|
||||||
|
|
||||||
|
You are provided with function signatures within <tools></tools> XML tags:
|
||||||
|
<tools>
|
||||||
|
{{- range .Tools }}
|
||||||
|
{"type": "function", "function": {{ .Function }}}
|
||||||
|
{{- end }}
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||||
|
<tool_call>
|
||||||
|
{"name": <function-name>, "arguments": <args-json-object>}
|
||||||
|
</tool_call>
|
||||||
|
{{- end }}<|im_end|>
|
||||||
|
{{ end }}
|
||||||
|
{{- range $i, $_ := .Messages }}
|
||||||
|
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||||
|
{{- if eq .Role "user" }}<|im_start|>user
|
||||||
|
{{ .Content }}<|im_end|>
|
||||||
|
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||||
|
{{ if .Content }}{{ .Content }}
|
||||||
|
{{- else if .ToolCalls }}<tool_call>
|
||||||
|
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||||
|
{{ end }}</tool_call>
|
||||||
|
{{- end }}{{ if not $last }}<|im_end|>
|
||||||
|
{{ end }}
|
||||||
|
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||||
|
<tool_response>
|
||||||
|
{{ .Content }}
|
||||||
|
</tool_response><|im_end|>
|
||||||
|
{{ end }}
|
||||||
|
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||||
|
{{ end }}
|
||||||
|
{{- end }}
|
||||||
|
{{- else }}
|
||||||
|
{{- if .System }}<|im_start|>system
|
||||||
|
{{ .System }}<|im_end|>
|
||||||
|
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||||
|
{{ .Prompt }}<|im_end|>
|
||||||
|
{{ end }}<|im_start|>assistant
|
||||||
|
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}
|
|
@ -0,0 +1,31 @@
|
||||||
|
<|im_start|>system
|
||||||
|
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
|
||||||
|
You may call one or more functions to assist with the user query.
|
||||||
|
|
||||||
|
You are provided with function signatures within <tools></tools> XML tags:
|
||||||
|
<tools>
|
||||||
|
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||||
|
<tool_call>
|
||||||
|
{"name": <function-name>, "arguments": <args-json-object>}
|
||||||
|
</tool_call><|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
What's the weather like today in Paris?<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
<tool_call>
|
||||||
|
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||||
|
</tool_call><|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
<tool_response>
|
||||||
|
22
|
||||||
|
</tool_response><|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
What's the weather like today in San Francisco and Toronto?<|im_end|>
|
||||||
|
<|im_start|>assistant
|
|
@ -0,0 +1,253 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
gotmpl "text/template"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errInvalidToolCall = errors.New("invalid tool call format")
|
||||||
|
errAccumulateMore = errors.New("need to accumulate more content")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Parser struct {
|
||||||
|
greedyParseJSON bool
|
||||||
|
prefix string
|
||||||
|
prefixFound bool
|
||||||
|
tmpl gotmpl.Template
|
||||||
|
sb strings.Builder
|
||||||
|
index int
|
||||||
|
name string
|
||||||
|
arguments string
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - s: The string to parse
|
||||||
|
// - name: The field name from template that identifies the tool call name
|
||||||
|
// - arguments: The field name from template that identifies the tool call arguments
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []api.ToolCall: The parsed tool calls if successful
|
||||||
|
// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful
|
||||||
|
func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) {
|
||||||
|
// Check for balanced braces before attempting to parse
|
||||||
|
braceCount := 0
|
||||||
|
squareCount := 0
|
||||||
|
startIndex := -1
|
||||||
|
var rawToolCalls []string
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
|
// Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case.
|
||||||
|
trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[")
|
||||||
|
for i, c := range s {
|
||||||
|
switch c {
|
||||||
|
case '{':
|
||||||
|
braceCount++
|
||||||
|
if startIndex == -1 {
|
||||||
|
startIndex = i
|
||||||
|
}
|
||||||
|
case '}':
|
||||||
|
braceCount--
|
||||||
|
if braceCount == 0 {
|
||||||
|
rawToolCalls = append(rawToolCalls, s[startIndex:i+1])
|
||||||
|
startIndex = -1
|
||||||
|
}
|
||||||
|
case '[':
|
||||||
|
if trackSquareBrackets {
|
||||||
|
squareCount++
|
||||||
|
}
|
||||||
|
case ']':
|
||||||
|
if trackSquareBrackets {
|
||||||
|
squareCount--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Negative means we have an extra closing brace/bracket
|
||||||
|
if braceCount < 0 || squareCount < 0 {
|
||||||
|
return nil, errInvalidToolCall
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If braces/brackets aren't balanced, need more input
|
||||||
|
if braceCount > 0 || squareCount > 0 {
|
||||||
|
return nil, errAccumulateMore
|
||||||
|
}
|
||||||
|
|
||||||
|
t := strings.TrimSpace(s)
|
||||||
|
if len(t) == 0 {
|
||||||
|
return nil, errAccumulateMore
|
||||||
|
}
|
||||||
|
// If the input is a single square bracket, it's not a valid tool call
|
||||||
|
if t[0] == '[' && len(t) == 1 {
|
||||||
|
return nil, errAccumulateMore
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt full unmarshal of the JSON
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
for _, rawToolCall := range rawToolCalls {
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect nested objects that could contain tool calls
|
||||||
|
objs := collect(resp)
|
||||||
|
if len(objs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract tool calls from objects
|
||||||
|
for _, kv := range objs {
|
||||||
|
n, nok := kv[name].(string)
|
||||||
|
a, aok := kv[arguments].(map[string]any)
|
||||||
|
if nok && aok {
|
||||||
|
toolCalls = append(toolCalls, api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: n,
|
||||||
|
Arguments: a,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
slog.Debug("No valid tool call found in object.", "object", kv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid JSON, no tool calls found
|
||||||
|
if len(toolCalls) == 0 {
|
||||||
|
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls)
|
||||||
|
return nil, errInvalidToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
return toolCalls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkPrefix processes a string to find and handle a prefix pattern.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - The processed string with prefix removed if found
|
||||||
|
// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful
|
||||||
|
func (p *Parser) checkPrefix(s string) (string, error) {
|
||||||
|
if s == "" || p.prefix == "" {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for prefix at start of string
|
||||||
|
if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
|
||||||
|
// Found prefix at start - accumulate for potential tool
|
||||||
|
p.prefixFound = true
|
||||||
|
return cut, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if prefix overlaps end of string
|
||||||
|
if idx := suffixOverlap(s, p.prefix); idx != -1 {
|
||||||
|
// Return everything except overlapping portion
|
||||||
|
p.sb.Reset()
|
||||||
|
p.sb.WriteString(s[idx:])
|
||||||
|
return s[:idx], errAccumulateMore
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if prefix appears in middle of string
|
||||||
|
if idx := strings.Index(s, p.prefix); idx != -1 {
|
||||||
|
// Save remainder starting at prefix for next pass
|
||||||
|
p.sb.Reset()
|
||||||
|
p.sb.WriteString(strings.TrimSpace(s[idx:]))
|
||||||
|
// Return everything before prefix
|
||||||
|
return s[:idx], errAccumulateMore
|
||||||
|
}
|
||||||
|
|
||||||
|
// No partial prefix found
|
||||||
|
return s, nil
|
||||||
|
}

// Add processes a string input to parse tool calls and content.
// It handles prefix detection and JSON parsing to extract tool calls.
//
// Returns:
// - tools: Any parsed tool calls
// - content: Non-tool call content
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
	p.sb.WriteString(s)
	s = p.sb.String()

	// Check for prefix pattern in input
	s, err := p.checkPrefix(s)
	if err != nil {
		// Need more input to complete prefix
		return nil, s
	}

	// Exit if prefix exists in template, greedy parsing is off, and prefix not found
	if !p.greedyParseJSON && !p.prefixFound {
		p.sb.Reset()
		return nil, s
	}

	toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix)
	if err != nil {
		if errors.Is(err, errAccumulateMore) {
			return nil, ""
		}
		p.sb.Reset()
		// Only do greedy JSON parsing if there is no prefix from template
		if p.prefix != "" {
			p.greedyParseJSON = false
		}
		if p.index != 0 && p.prefix == "" {
			return nil, ""
		}
		if p.prefixFound {
			// Drop tokens since prefix was found
			return nil, ""
		}
		return nil, s
	}

	for _, tc := range toolCalls {
		tc.Function.Index = p.index
		p.index++
	}

	p.sb.Reset()
	return toolCalls, ""
}

// NewParser creates a new tool call parser from a template. It extracts the tool call format,
// prefix, and field names from the template to use for parsing tool calls from model output.
//
// Returns an error if the template does not contain valid tool call formatting.
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
	parsed, err := template.Parse(templateToProcess.Root.String())
	if err != nil {
		return nil, err
	}

	tt, err := toolTemplate(parsed)
	if err != nil {
		return nil, err
	}

	tp := toolPrefix(templateToProcess)

	name, arguments, err := extractToolArgs(tt)
	if err != nil {
		return nil, err
	}

	return &Parser{
		tmpl:            *tt,
		sb:              strings.Builder{},
		prefix:          tp,
		greedyParseJSON: true,
		name:            name,
		arguments:       arguments,
	}, nil
}
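
// Illustrative usage sketch (assumed caller code, not part of this change):
// streamed model output is fed token by token through Parser.Add; finished
// tool calls come back in the first return value, plain content in the second.
//
//	parser, err := NewParser(tmpl) // tmpl: *text/template.Template containing a .ToolCalls section
//	if err != nil {
//		// template has no tool call formatting; stream content without parsing
//	}
//	var calls []api.ToolCall
//	for _, token := range outputTokens { // outputTokens: assumed slice of streamed strings
//		tc, content := parser.Add(token)
//		calls = append(calls, tc...)
//		if content != "" {
//			fmt.Print(content) // forward non-tool text to the client
//		}
//	}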
@ -0,0 +1,673 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
bts, err := os.ReadFile(filepath.Join(base, name))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.NewBuffer(bts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseJSONToolCalls(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
nameField string
|
||||||
|
argsField string
|
||||||
|
wantToolCalls []api.ToolCall
|
||||||
|
wantErr error
|
||||||
|
prefix string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid single tool call",
|
||||||
|
input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "test_tool",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg1": "value1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete JSON",
|
||||||
|
input: `{"name": "test_tool", "arguments": {"arg1": `,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: nil,
|
||||||
|
wantErr: errAccumulateMore,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON",
|
||||||
|
input: `not json at all`,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: nil,
|
||||||
|
wantErr: errInvalidToolCall,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing required fields",
|
||||||
|
input: `{"other": "field"}`,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: nil,
|
||||||
|
wantErr: errInvalidToolCall,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool calls in array",
|
||||||
|
input: `[
|
||||||
|
{"name": "tool1", "arguments": {"arg1": 1}},
|
||||||
|
{"name": "tool2", "arguments": {"arg2": "value"}}
|
||||||
|
]`,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "tool1",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg1": float64(1),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "tool2",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg2": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool calls without array",
|
||||||
|
input: `
|
||||||
|
{"name": "tool1", "arguments": {"arg1": 1}},
|
||||||
|
{"name": "tool2", "arguments": {"arg2": "value"}}
|
||||||
|
`,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "tool1",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg1": float64(1),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "tool2",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg2": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool calls with text after",
|
||||||
|
input: `
|
||||||
|
{"name": "tool1", "arguments": {"arg1": 1}} text
|
||||||
|
{"name": "tool2", "arguments": {"arg2": "value"}} text
|
||||||
|
`,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "tool1",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg1": float64(1),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "tool2",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg2": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second tool call in array",
|
||||||
|
input: `
|
||||||
|
, {"name": "tool2", "arguments": {"arg2": "value"}}
|
||||||
|
`,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "tool2",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg2": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
// malformed JSON never returns tool calls or content, as the parser keeps accumulating more input
|
||||||
|
{
|
||||||
|
name: "unbalanced square brackets",
|
||||||
|
input: `[{"name": "tool1", "arguments": {"arg1": [1, 2}]`,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: nil,
|
||||||
|
wantErr: errAccumulateMore,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete square brackets",
|
||||||
|
input: `[{"name": "tool1", "arguments": {"arg1": [1, 2, 3`,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: nil,
|
||||||
|
wantErr: errAccumulateMore,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested arrays in arguments",
|
||||||
|
input: `{"name": "tool1", "arguments": {"arg1": [1, 2, ["nested", "array"]]}}`,
|
||||||
|
nameField: "name",
|
||||||
|
argsField: "arguments",
|
||||||
|
wantToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "tool1",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg1": []any{float64(1), float64(2), []any{"nested", "array"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
prefix: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotCalls, err := parseJSONToolCalls(tt.input, tt.nameField, tt.argsField, tt.prefix)
|
||||||
|
|
||||||
|
if err != tt.wantErr {
|
||||||
|
t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(gotCalls) != 0 && tt.wantErr != nil {
|
||||||
|
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" {
|
||||||
|
t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCalls(t *testing.T) {
|
||||||
|
p := filepath.Join("testdata")
|
||||||
|
t1 := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
t2 := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "Toronto, Canada",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
output string
|
||||||
|
expectedToolCall []api.ToolCall
|
||||||
|
expectedTokens string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "mistral malformed json with tool calls prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral multiple tool calls without prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} ]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral tool calls with text between no prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
|
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral valid json with tool calls prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral multiple tool calls with text between and prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
|
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2, t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral incomplete json with tool calls prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral invalid tool call with explanatory text no prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||||
|
|
||||||
|
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral tool calls without prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command r plus tool calls with json block format",
|
||||||
|
model: "command-r-plus",
|
||||||
|
output: "Action: ```json" + `
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tool_name": "get_current_weather",
|
||||||
|
"parameters": {
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tool_name": "get_current_weather",
|
||||||
|
"parameters": {
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "Toronto, Canada"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
` + "```",
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "firefunction tool calls with functools prefix",
|
||||||
|
model: "firefunction",
|
||||||
|
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "llama3 groq single tool call with xml tags",
|
||||||
|
model: "llama3-groq-tool-use",
|
||||||
|
output: `<tool_call>
|
||||||
|
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||||
|
</tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "xlam tool calls with wrapper object",
|
||||||
|
model: "xlam",
|
||||||
|
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 single tool call with prefix",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 multiple tool calls with and without prefix",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call> <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 plain text response no tool calls",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls with trailing text",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "some tokens after call",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls with initial text",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls with prefix and trailing text",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls with prefix and initial text",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `some tokens before call <tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "some tokens before call",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls without and with prefix",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls without and with prefix and text between",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} some tokens between <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call> some tokens after call`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "some tokens between",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls without prefix and invalid tool call with other tokens",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `hi [{"options": "foo"}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `hi [{"options": "foo"}]`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls with prefix and invalid tool call",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `<tool_call> [{"options": "foo"}] </tool_call> `,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: ``,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
|
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think>Okay, let me think what tool we should use...</think> <tool_call>{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
|
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 empty think prefix without tool prefix and invalid tool call",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think></think> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `<think></think> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 empty think prefix with tool prefix and valid tool call",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think></think><tool_call>{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
|
expectedTokens: `<think></think>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think></think>< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `<think></think>< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think></think><tool_c fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `<think></think><tool_c fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 invalid tool call with malformed tool prefix",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with prefix in template, no prefix in output",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with prefix in template, prefix in output",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `<tool_call>[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model without prefix in template, no prefix in output",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model without prefix in template, no prefix in output, single tool call",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model without prefix in template, prefix in output, multiple tool calls in list",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: `<tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model without prefix in template, prefix in output, individual tool calls",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `<tool_call> {"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: `<tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with prefix in template, no prefix in output, tokens before",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with prefix in template, prefix in output, tokens after",
|
||||||
|
model: "qwen2.5",
|
||||||
|
output: `<tool_call>[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model without prefix in template, no prefix in output, tokens after",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model without prefix in template, no prefix in output, tokens before",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: `some tokens before`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model without prefix in template, prefix in output, tokens after",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `<tool_call>
|
||||||
|
[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: `<tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name:             "model without prefix, match all jsons",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `model outputs some text [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "model outputs some text",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model flushes tokens if tool call doesn't match",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model flushes tokens if tool call doesn't match array",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var tools []api.Tool
|
||||||
|
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var messages []api.Message
|
||||||
|
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("template", func(t *testing.T) {
|
||||||
|
actual := &bytes.Buffer{} // Create new buffer for each test
|
||||||
|
if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("parse", func(t *testing.T) {
|
||||||
|
tp, err := NewParser(tmpl.Template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
got := []api.ToolCall{}
|
||||||
|
var gotTokens strings.Builder
|
||||||
|
|
||||||
|
tokens := strings.Fields(tt.output)
|
||||||
|
for _, tok := range tokens {
|
||||||
|
s := " " + tok
|
||||||
|
|
||||||
|
toolCalls, content := tp.Add(s)
|
||||||
|
if len(content) > 0 {
|
||||||
|
gotTokens.WriteString(content)
|
||||||
|
} else if len(toolCalls) > 0 {
|
||||||
|
got = append(got, toolCalls...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare tool calls if we expect any
|
||||||
|
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
|
||||||
|
t.Errorf("tool calls mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare tokens if we expect any
|
||||||
|
stripped := strings.TrimSpace(gotTokens.String())
|
||||||
|
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
|
||||||
|
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
|
||||||
|
t.Errorf("tokens mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}

@ -0,0 +1,227 @@
package tools

import (
	"bytes"
	"encoding/json"
	"errors"
	"log/slog"
	"slices"
	"strings"
	gotmpl "text/template"
	"text/template/parse"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/template"
)

// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
//
// Returns:
// - string: The extracted text following the first ".ToolCalls" condition found
// - bool: Whether a ".ToolCalls" condition was found in the template
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
	if tmpl == nil || tmpl.Tree == nil {
		slog.Debug("template or tree is nil")
		return "", false
	}

	var result string
	var found bool

	var walk func(nodes []parse.Node)
	walk = func(nodes []parse.Node) {
		for _, node := range nodes {
			if found {
				return
			}

			switch n := node.(type) {
			case *parse.IfNode:
				if isToolCallsNode(n) {
					// Collect immediate TextNode(s) at start of IfNode's list
					var sb strings.Builder
					for _, innerNode := range n.List.Nodes {
						if tn, ok := innerNode.(*parse.TextNode); ok {
							sb.Write(tn.Text)
						} else {
							// Stop at first non-text node
							break
						}
					}
					result = sb.String()
					found = true
					return
				}
				// Recurse into child nodes
				walk(n.List.Nodes)
				if n.ElseList != nil {
					walk(n.ElseList.Nodes)
				}
			case *parse.ListNode:
				walk(n.Nodes)
			case *parse.RangeNode:
				walk(n.List.Nodes)
				if n.ElseList != nil {
					walk(n.ElseList.Nodes)
				}
			case *parse.WithNode:
				walk(n.List.Nodes)
				if n.ElseList != nil {
					walk(n.ElseList.Nodes)
				}
			default:
				// Continue to next node
				continue
			}
		}
	}

	walk(tmpl.Tree.Root.Nodes)
	return result, found
}

// isToolCallsNode detects if a node's condition includes ".ToolCalls"
func isToolCallsNode(n *parse.IfNode) bool {
	for _, cmd := range n.Pipe.Cmds {
		for _, arg := range cmd.Args {
			if field, ok := arg.(*parse.FieldNode); ok {
				if slices.Contains(field.Ident, "ToolCalls") {
					return true
				}
			}
		}
	}
	return false
}

func toolPrefix(tmpl *gotmpl.Template) string {
	tokenText, ok := extractToolCallsFormat(tmpl)
	if !ok {
		return ""
	}
	tokenText = strings.TrimSpace(tokenText)
	tokenText = strings.ReplaceAll(tokenText, "\r", "")
	tokenText = strings.ReplaceAll(tokenText, "\n", " ")

	return tokenText
}

// toolTemplate creates a subtree from the node that ranges over .ToolCalls
//
// Returns:
// - *gotmpl.Template: The subtree containing the .ToolCalls range
// - error: Error if parsing failed
func toolTemplate(t *template.Template) (*gotmpl.Template, error) {
	tmpl := t.Subtree(func(n parse.Node) bool {
		if t, ok := n.(*parse.RangeNode); ok {
			return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
		}

		return false
	})

	if tmpl == nil {
		return nil, errors.New("failed to find tool template")
	}

	return tmpl, nil
}

// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins
//
// Returns:
// - int: The starting index in s where the suffix overlap begins
func suffixOverlap(s, prefix string) int {
	max := min(len(prefix), len(s))
	for i := max; i > 0; i-- {
		if strings.HasSuffix(s, prefix[:i]) {
			return len(s) - i
		}
	}
	return -1
}

// extractToolArgs executes a template with a known tool call format to extract the name and arguments
//
// Returns:
// - string: The name of the tool call
// - string: The arguments of the tool call
// - error: Error if parsing failed
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
	var b bytes.Buffer
	if err := tmpl.Execute(&b, map[string][]api.ToolCall{
		"ToolCalls": {
			{
				Function: api.ToolCallFunction{
					Name: "@@name@@",
					Arguments: api.ToolCallFunctionArguments{
						"@@argument@@": 1,
					},
				},
			},
		},
	}); err != nil {
		return "", "", err
	}

	var obj any
	err = json.Unmarshal(b.Bytes(), &obj)
	if err != nil {
		return "", "", err
	}

	var objs []map[string]any
	switch v := obj.(type) {
	case map[string]any:
		objs = []map[string]any{v}
	case []map[string]any:
		objs = v
	case []any:
		objs = collect(v)
	}
	if len(objs) == 0 {
		return "", "", errors.New("no template objects found")
	}

	// find the keys that correspond to the name and arguments fields
	for k, v := range objs[0] {
		switch v.(type) {
		case string:
			name = k
		case map[string]any:
			arguments = k
		}
	}

	if name == "" || arguments == "" {
		slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
		return "", "", errors.New("missing required fields in tool call template")
	}

	return name, arguments, nil
}

// collect recursively traverses an object to collect all nested maps
//
// Returns:
// - []map[string]any: A slice of all nested maps found in the object
func collect(obj any) []map[string]any {
	var all []map[string]any
	switch o := obj.(type) {
	case map[string]any:
		all = append(all, o)
		for _, v := range o {
			all = append(all, collect(v)...)
		}
	case []any:
		for _, v := range o {
			all = append(all, collect(v)...)
		}
	default:
		return nil
	}

	return all
}
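
// Illustrative sketch (assumed rendering, not part of this change): if the
// .ToolCalls subtree renders each call as {"name": ..., "arguments": ...},
// executing it with the sentinel call above yields JSON like
//
//	{"name": "@@name@@", "arguments": {"@@argument@@": 1}}
//
// so extractToolArgs reports name = "name" and arguments = "arguments". A
// template that emits "tool_name"/"parameters" instead (as command-r-plus does
// in the tests) yields those keys.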
@ -0,0 +1,464 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
gotmpl "text/template"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractToolCallsFormat(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
want string
|
||||||
|
found bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil template",
|
||||||
|
template: "",
|
||||||
|
want: "",
|
||||||
|
found: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "basic tool call with text",
|
||||||
|
template: "{{if .ToolCalls}}Hello world{{end}}",
|
||||||
|
want: "Hello world",
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with json format",
|
||||||
|
template: "{{if .ToolCalls}}```json\n{{end}}",
|
||||||
|
want: "```json\n",
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call in range",
|
||||||
|
template: "{{range .ToolCalls}}tool: {{.}}{{end}}",
|
||||||
|
want: "",
|
||||||
|
found: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with multiple text nodes",
|
||||||
|
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
|
||||||
|
want: "First text",
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested if without tool calls",
|
||||||
|
template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}",
|
||||||
|
want: "",
|
||||||
|
found: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
tmpl, err := gotmpl.New("test").Parse(tc.template)
|
||||||
|
if err != nil && tc.template != "" {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, found := extractToolCallsFormat(tmpl)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("got text %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
if found != tc.found {
|
||||||
|
t.Errorf("got found %v, want %v", found, tc.found)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolPrefix(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic tool call with action prefix",
|
||||||
|
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
||||||
|
want: "Action: ```json",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete functools bracket",
|
||||||
|
template: "{{if .ToolCalls}}functools[{{end}}",
|
||||||
|
want: "functools[",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with angle brackets",
|
||||||
|
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
|
||||||
|
want: "Hello, world! <tool_call>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool call formats",
|
||||||
|
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
|
||||||
|
want: "[tool_call] <tool_call>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single angle bracket tool call",
|
||||||
|
template: "{{if .ToolCalls}}<tool_call>{{end}}",
|
||||||
|
want: "<tool_call>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete angle bracket after tool call",
|
||||||
|
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
|
||||||
|
want: "[tool_call] <",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "angle bracket prefix with tool call",
|
||||||
|
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
|
||||||
|
want: "> <tool_call>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase tool call with incomplete bracket",
|
||||||
|
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
||||||
|
want: "[TOOL_CALL] [",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase tool call with adjacent bracket",
|
||||||
|
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
||||||
|
want: "[TOOL_CALL][",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with pipe delimiters",
|
||||||
|
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
|
||||||
|
want: "<|tool_call|>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool with no prefix",
|
||||||
|
template: "{{if .ToolCalls}}{{end}}",
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
got := toolPrefix(tmpl)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolTemplate(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic tool call range",
|
||||||
|
template: "{{range .ToolCalls}}test{{end}}",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no tool calls",
|
||||||
|
template: "{{range .Other}}test{{end}}",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested tool calls",
|
||||||
|
template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty template",
|
||||||
|
template: "",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool calls in if statement",
|
||||||
|
template: "{{if .ToolCalls}}test{{end}}",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := template.Parse(tmpl.Root.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = toolTemplate(parsed)
|
||||||
|
if err != nil && tt.want {
|
||||||
|
t.Errorf("toolTemplate() = %v; want %v", err, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSuffixOverlap(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
s string
|
||||||
|
d string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no overlap",
|
||||||
|
s: "hello world",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: -1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full overlap",
|
||||||
|
s: "<tool_call>",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial overlap",
|
||||||
|
s: "text <tool_call>",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delimiter longer than string",
|
||||||
|
s: "<tool>",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: -1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
s: "",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: -1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty delimiter",
|
||||||
|
s: "<tool_call>",
|
||||||
|
d: "",
|
||||||
|
want: -1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single char overlap",
|
||||||
|
s: "test<",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial tool call",
|
||||||
|
s: "hello <tool_",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: 6,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := suffixOverlap(tt.s, tt.d)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractToolArgs(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
want string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic tool call with text after",
|
||||||
|
template: `{{if .ToolCalls}}tool response{{end}}`,
|
||||||
|
want: "tool response",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with mixed content after",
|
||||||
|
template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`,
|
||||||
|
want: "<tool_call>",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with no text after",
|
||||||
|
template: `{{if .ToolCalls}}{{.Something}}{{end}}`,
|
||||||
|
want: "",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested tool call",
|
||||||
|
template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`,
|
||||||
|
want: "[TOOL_CALL]",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no tool calls",
|
||||||
|
template: `{{if .Something}}no tools here{{end}}`,
|
||||||
|
want: "",
|
||||||
|
ok: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty template",
|
||||||
|
template: ``,
|
||||||
|
want: "",
|
||||||
|
ok: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool calls sections",
|
||||||
|
template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`,
|
||||||
|
want: "first",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "range over tool calls",
|
||||||
|
template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`,
|
||||||
|
want: "",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool calls with pipe delimiters",
|
||||||
|
template: `{{if .ToolCalls}}<|tool|>{{end}}`,
|
||||||
|
want: "<|tool|>",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool calls with nested template",
|
||||||
|
template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`,
|
||||||
|
want: "",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool calls with whitespace variations",
|
||||||
|
template: `{{if .ToolCalls}} tool {{end}}`,
|
||||||
|
want: " tool ",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, ok := extractToolCallsFormat(tmpl)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollect(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
obj any
|
||||||
|
want []map[string]any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple map",
|
||||||
|
obj: map[string]any{
|
||||||
|
"key": "value",
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"key": "value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested map",
|
||||||
|
obj: map[string]any{
|
||||||
|
"outer": map[string]any{
|
||||||
|
"inner": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"outer": map[string]any{"inner": "value"}},
|
||||||
|
{"inner": "value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array of maps",
|
||||||
|
obj: []any{
|
||||||
|
map[string]any{"key1": "val1"},
|
||||||
|
map[string]any{"key2": "val2"},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"key1": "val1"},
|
||||||
|
{"key2": "val2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deeply nested",
|
||||||
|
obj: map[string]any{
|
||||||
|
"l1": map[string]any{
|
||||||
|
"l2": map[string]any{
|
||||||
|
"l3": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"l1": map[string]any{"l2": map[string]any{"l3": "value"}}},
|
||||||
|
{"l2": map[string]any{"l3": "value"}},
|
||||||
|
{"l3": "value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-map value",
|
||||||
|
obj: "string",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := collect(tt.obj)
|
||||||
|
if len(got) != len(tt.want) {
|
||||||
|
t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare each map in the result
|
||||||
|
for i := range tt.want {
|
||||||
|
if !mapsEqual(got[i], tt.want[i]) {
|
||||||
|
t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapsEqual compares two maps for deep equality
|
||||||
|
func mapsEqual(m1, m2 map[string]any) bool {
|
||||||
|
if len(m1) != len(m2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for k, v1 := range m1 {
|
||||||
|
v2, ok := m2[k]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch val1 := v1.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
val2, ok := v2.(map[string]any)
|
||||||
|
if !ok || !mapsEqual(val1, val2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if v1 != v2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}