Merge branch 'ollama:main' into mmap

frob 2025-09-06 12:42:35 +02:00 committed by GitHub
commit c542b9dd8e
64 changed files with 1783 additions and 566 deletions

View File

@ -541,6 +541,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
+- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
### Mobile
@ -601,6 +602,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
+- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
### Supported backends

View File

@ -286,16 +286,23 @@ func mapToTypeScriptType(jsonType string) string {
    }
}

-type ToolFunction struct {
-    Name        string `json:"name"`
-    Description string `json:"description"`
-    Parameters  struct {
+type ToolFunctionParameters struct {
    Type       string                  `json:"type"`
    Defs       any                     `json:"$defs,omitempty"`
    Items      any                     `json:"items,omitempty"`
    Required   []string                `json:"required"`
    Properties map[string]ToolProperty `json:"properties"`
-    } `json:"parameters"`
+}
+
+func (t *ToolFunctionParameters) String() string {
+    bts, _ := json.Marshal(t)
+    return string(bts)
+}
+
+type ToolFunction struct {
+    Name        string                 `json:"name"`
+    Description string                 `json:"description"`
+    Parameters  ToolFunctionParameters `json:"parameters"`
}

func (t *ToolFunction) String() string {
func (t *ToolFunction) String() string { func (t *ToolFunction) String() string {
@ -881,7 +888,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
        if t < 0 {
            d.Duration = time.Duration(math.MaxInt64)
        } else {
-            d.Duration = time.Duration(int(t) * int(time.Second))
+            d.Duration = time.Duration(t * float64(time.Second))
        }
    case string:
        d.Duration, err = time.ParseDuration(t)
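For illustration (not part of the patch): the switch from an int cast to a float64 multiply is what lets fractional keep_alive values survive. A minimal standalone sketch of the before/after behavior:

package main

import (
    "fmt"
    "time"
)

func main() {
    t := 42.5 // keep_alive given as seconds in JSON

    // Old behavior: truncates to whole seconds before scaling.
    old := time.Duration(int(t) * int(time.Second))
    // New behavior: scales the float directly, keeping the fraction.
    fixed := time.Duration(t * float64(time.Second))

    fmt.Println(old)   // 42s
    fmt.Println(fixed) // 42.5s
}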

View File

@ -17,6 +17,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
        req  string
        exp  *Duration
    }{
+        {
+            name: "Unset",
+            req:  `{ }`,
+            exp:  nil,
+        },
        {
            name: "Positive Integer",
            req:  `{ "keep_alive": 42 }`,
@ -25,7 +30,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
        {
            name: "Positive Float",
            req:  `{ "keep_alive": 42.5 }`,
-            exp:  &Duration{42 * time.Second},
+            exp:  &Duration{42500 * time.Millisecond},
        },
        {
            name: "Positive Integer String",
@ -436,3 +441,50 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
        })
    }
}
+
+func TestToolFunctionParameters_String(t *testing.T) {
+    tests := []struct {
+        name     string
+        params   ToolFunctionParameters
+        expected string
+    }{
+        {
+            name: "simple object with string property",
+            params: ToolFunctionParameters{
+                Type:     "object",
+                Required: []string{"name"},
+                Properties: map[string]ToolProperty{
+                    "name": {
+                        Type:        PropertyType{"string"},
+                        Description: "The name of the person",
+                    },
+                },
+            },
+            expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
+        },
+        {
+            name: "marshal failure returns empty string",
+            params: ToolFunctionParameters{
+                Type: "object",
+                Defs: func() any {
+                    // Create a cycle that will cause json.Marshal to fail
+                    type selfRef struct {
+                        Self *selfRef
+                    }
+                    s := &selfRef{}
+                    s.Self = s
+                    return s
+                }(),
+                Properties: map[string]ToolProperty{},
+            },
+            expected: "",
+        },
+    }
+
+    for _, test := range tests {
+        t.Run(test.name, func(t *testing.T) {
+            result := test.params.String()
+            assert.Equal(t, test.expected, result)
+        })
+    }
+}

View File

@ -16,17 +16,22 @@ import (
type gptossModel struct {
    ModelParameters
    HiddenLayers          uint32  `json:"num_hidden_layers"`
+    MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
    HiddenSize            uint32  `json:"hidden_size"`
    IntermediateSize      uint32  `json:"intermediate_size"`
    AttentionHeads        uint32  `json:"num_attention_heads"`
    KeyValueHeads         uint32  `json:"num_key_value_heads"`
    HeadDim               uint32  `json:"head_dim"`
    Experts               uint32  `json:"num_experts"`
+    LocalExperts          uint32  `json:"num_local_experts"`
    ExpertsPerToken       uint32  `json:"experts_per_token"`
    RMSNormEpsilon        float32 `json:"rms_norm_eps"`
    InitialContextLength  uint32  `json:"initial_context_length"`
    RopeTheta             float32 `json:"rope_theta"`
    RopeScalingFactor     float32 `json:"rope_scaling_factor"`
+    RopeScaling           struct {
+        Factor float32 `json:"factor"`
+    } `json:"rope_scaling"`
    SlidingWindow         uint32  `json:"sliding_window"`
}
@ -36,11 +41,11 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
    kv := m.ModelParameters.KV(t)
    kv["general.architecture"] = "gptoss"
    kv["general.file_type"] = uint32(4)
-    kv["gptoss.context_length"] = uint32(m.RopeScalingFactor * float32(m.InitialContextLength))
+    kv["gptoss.context_length"] = cmp.Or(m.MaxPositionEmbeddings, uint32(m.RopeScalingFactor*float32(m.InitialContextLength)))
    kv["gptoss.block_count"] = m.HiddenLayers
    kv["gptoss.embedding_length"] = m.HiddenSize
    kv["gptoss.feed_forward_length"] = m.IntermediateSize
-    kv["gptoss.expert_count"] = m.Experts
+    kv["gptoss.expert_count"] = cmp.Or(m.Experts, m.LocalExperts)
    kv["gptoss.expert_used_count"] = m.ExpertsPerToken
    kv["gptoss.attention.head_count"] = m.AttentionHeads
    kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
@ -49,7 +54,7 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
    kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
    kv["gptoss.attention.sliding_window"] = m.SlidingWindow
    kv["gptoss.rope.freq_base"] = m.RopeTheta
-    kv["gptoss.rope.scaling.factor"] = m.RopeScalingFactor
+    kv["gptoss.rope.scaling.factor"] = cmp.Or(m.RopeScalingFactor, m.RopeScaling.Factor)
    kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
    kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
    kv["tokenizer.ggml.add_bos_token"] = false
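For illustration (not part of the patch): cmp.Or returns the first of its arguments that is not the zero value, so each fallback above picks whichever config field was actually populated, whether the checkpoint uses the original gpt-oss config keys or the HF-flavored ones. A small standalone sketch with made-up values:

package main

import (
    "cmp"
    "fmt"
)

func main() {
    // HF-style configs set num_local_experts; the original config sets num_experts.
    var experts, localExperts uint32 = 0, 32

    // cmp.Or returns its first non-zero argument, so whichever field
    // was present in the config wins.
    fmt.Println(cmp.Or(experts, localExperts)) // 32
}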
@ -92,6 +97,11 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
    for name, mxfp4 := range mxfp4s {
        dims := mxfp4.blocks.Shape()

+        if !strings.HasSuffix(name, ".weight") {
+            name += ".weight"
+        }
+
        out = append(out, &ggml.Tensor{
            Name: name,
            Kind: uint32(ggml.TensorTypeMXFP4),
@ -104,7 +114,27 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
}

func (m *gptossModel) Replacements() []string {
-    return []string{
+    var replacements []string
+    if m.MaxPositionEmbeddings > 0 {
+        // hf flavored model
+        replacements = []string{
+            "lm_head", "output",
+            "model.embed_tokens", "token_embd",
+            "model.layers", "blk",
+            "input_layernorm", "attn_norm",
+            "self_attn.q_proj", "attn_q",
+            "self_attn.k_proj", "attn_k",
+            "self_attn.v_proj", "attn_v",
+            "self_attn.o_proj", "attn_out",
+            "self_attn.sinks", "attn_sinks",
+            "post_attention_layernorm", "ffn_norm",
+            "mlp.router", "ffn_gate_inp",
+            "mlp.experts.gate_up_proj_", "ffn_gate_up_exps.",
+            "mlp.experts.down_proj_", "ffn_down_exps.",
+            "model.norm", "output_norm",
+        }
+    } else {
+        replacements = []string{
            // noop replacements so other replacements will not be applied
            ".blocks", ".blocks",
            ".scales", ".scales",
@ -123,6 +153,8 @@ func (m *gptossModel) Replacements() []string {
            "unembedding", "output",
            "scale", "weight",
        }
+    }
+    return replacements
}

type mxfp4 struct {
@ -140,7 +172,20 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
        blocksDims[i] = int(d)
    }

-    var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
+    bts := b.Bytes()
+    var tmp [16]byte
+    for i := 0; i < b.Len(); i += 16 {
+        for j := range 8 {
+            // transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
+            a, b := bts[i+j], bts[i+j+8]
+            tmp[2*j+0] = (a & 0x0F) | (b << 4)
+            tmp[2*j+1] = (a >> 4) | (b & 0xF0)
+        }
+        copy(bts[i:i+16], tmp[:])
+    }
+
+    var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))

    var s bytes.Buffer
    if _, err := m.scales.WriteTo(&s); err != nil {
@ -174,5 +219,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
        return 0, err
    }

-    return 0, nil
+    return int64(len(u8s)), nil
}
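For illustration (not part of the patch): the new loop re-packs the nibbles of each 16-byte MXFP4 block in place, pairing the low and high nibbles of byte j with those of byte j+8. A standalone sketch of the same transform on one synthetic block (the byte values are invented):

package main

import "fmt"

// repack interleaves the nibbles of one 16-byte MXFP4 block in place:
// output byte 2j holds the low nibbles of bytes j and j+8, and byte 2j+1
// holds their high nibbles, mirroring the loop in the converter above.
func repack(block []byte) {
    var tmp [16]byte
    for j := 0; j < 8; j++ {
        a, b := block[j], block[j+8]
        tmp[2*j+0] = (a & 0x0F) | (b << 4)
        tmp[2*j+1] = (a >> 4) | (b & 0xF0)
    }
    copy(block, tmp[:])
}

func main() {
    block := []byte{
        0xA1, 0xB2, 0xC3, 0xD4, 0xE5, 0xF6, 0x07, 0x18,
        0x29, 0x3A, 0x4B, 0x5C, 0x6D, 0x7E, 0x8F, 0x90,
    }
    repack(block)
    fmt.Printf("% X\n", block) // 91 2A A2 3B B3 4C C4 5D D5 6E E6 7F F7 80 08 91
}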

View File

@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 {
const (
    tensorKindFP32 uint32 = iota
    tensorKindFP16
-    tensorKindMXFP4 = 4
    tensorKindBF16  = 30
+    tensorKindMXFP4 = 39
)

func (t tensorBase) Kind() uint32 {

View File

@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
    switch st.Kind() {
    case tensorKindFP32:
-        return 0, binary.Write(w, binary.LittleEndian, f32s)
+        return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
    case tensorKindFP16:
        f16s := make([]uint16, len(f32s))
        for i := range f32s {
            f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
        }
-        return 0, binary.Write(w, binary.LittleEndian, f16s)
+        return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
    case tensorKindBF16:
        u8s := bfloat16.EncodeFloat32(f32s)
-        return 0, binary.Write(w, binary.LittleEndian, u8s)
+        return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
    default:
        return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
    }

View File

@ -277,6 +277,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
                FreeMemory: (totalMemory - usedMemory),
            },
            ID: ID,
+            filterID: gpuOrdinalID,
            Name: name,
            Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
            MinimumMemory: rocmMinimumMemory,
@ -394,7 +395,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
        // Check for env var workarounds
        if name == "1002:687f" { // Vega RX 56
-            gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
+            gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, "HSA_ENABLE_SDMA=0")
        }

        // The GPU has passed all the verification steps and is supported
@ -523,19 +524,26 @@ func verifyKFDDriverAccess() error {
    return nil
}

-func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
    ids := []string{}
    for _, info := range gpuInfo {
        if info.Library != "rocm" {
-            // TODO shouldn't happen if things are wired correctly...
-            slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
            continue
        }
+        // If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
+        if _, err := strconv.Atoi(info.ID); err == nil {
+            ids = append(ids, fmt.Sprintf("%d", info.filterID))
+        } else {
            ids = append(ids, info.ID)
        }
+    }
+    if len(ids) == 0 {
+        return ""
+    }

    // There are 3 potential env vars to use to select GPUs.
    // ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux
    // GPU_DEVICE_ORDINAL supports numeric IDs only
    // HIP_VISIBLE_DEVICES supports numeric IDs only
-    return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",")
+    return "ROCR_VISIBLE_DEVICES=" + strings.Join(ids, ",")
}

View File

@ -111,6 +111,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
            UnreliableFreeMemory: true,
            ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
+            filterID: i,
            DependencyPath: []string{libDir},
            MinimumMemory: rocmMinimumMemory,
            Name: name,
@ -200,19 +201,26 @@ func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
    return nil
}

-func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
    ids := []string{}
    for _, info := range gpuInfo {
        if info.Library != "rocm" {
-            // TODO shouldn't happen if things are wired correctly...
-            slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
            continue
        }
+        // If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
+        if _, err := strconv.Atoi(info.ID); err == nil {
+            ids = append(ids, fmt.Sprintf("%d", info.filterID))
+        } else {
            ids = append(ids, info.ID)
        }
+    }
+    if len(ids) == 0 {
+        return ""
+    }

    // There are 3 potential env vars to use to select GPUs.
    // ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows
    // HIP_VISIBLE_DEVICES supports numeric IDs only
    // GPU_DEVICE_ORDINAL supports numeric IDs only
-    return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
+    return "HIP_VISIBLE_DEVICES=" + strings.Join(ids, ",")
}

View File

@ -16,19 +16,6 @@ import (
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK")

-func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
-    ids := []string{}
-    for _, info := range gpuInfo {
-        if info.Library != "cuda" {
-            // TODO shouldn't happen if things are wired correctly...
-            slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
-            continue
-        }
-        ids = append(ids, info.ID)
-    }
-    return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
-}
-
func cudaVariant(gpuInfo CudaGPUInfo) string {
    if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
        if CudaTegra != "" {

View File

@ -371,6 +371,15 @@ func GetGPUInfo() GpuInfoList {
    }

    rocmGPUs, err = AMDGetGPUInfo()
+    // The ID field is used in context of the filtered set of GPUS
+    // so we have to replace any of these numeric IDs with their
+    // placement in this set of GPUs
+    for i := range rocmGPUs {
+        if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil {
+            rocmGPUs[i].ID = strconv.Itoa(i)
+        }
+    }
    if err != nil {
        bootstrapErrors = append(bootstrapErrors, err)
    }
@ -680,23 +689,16 @@ func getVerboseState() C.uint16_t {
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variable
-//
-// If different libraries are detected, the first one is what we use
-func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+func (l GpuInfoList) GetVisibleDevicesEnv() []string {
    if len(l) == 0 {
-        return "", ""
+        return nil
    }
-    switch l[0].Library {
-    case "cuda":
-        return cudaGetVisibleDevicesEnv(l)
-    case "rocm":
-        return rocmGetVisibleDevicesEnv(l)
-    case "oneapi":
-        return oneapiGetVisibleDevicesEnv(l)
-    default:
-        slog.Debug("no filter required for library " + l[0].Library)
-        return "", ""
+    vd := []string{}
+    // Only filter the AMD GPUs at this level, let all NVIDIA devices through
+    if tmp := rocmGetVisibleDevicesEnv(l); tmp != "" {
+        vd = append(vd, tmp)
    }
+    return vd
}

func GetSystemInfo() SystemInfo {
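For illustration (not part of the patch): GetVisibleDevicesEnv now returns ready-to-use "KEY=value" strings rather than a separate key and value. A hypothetical caller might splice them into a subprocess environment like this; the runner binary name and device values below are placeholders, not taken from this diff:

package main

import (
    "fmt"
    "os"
    "os/exec"
)

func main() {
    // Stand-in for gpus.GetVisibleDevicesEnv(), e.g. []string{"ROCR_VISIBLE_DEVICES=0,1"}.
    visibleDevicesEnv := []string{"ROCR_VISIBLE_DEVICES=0,1"}

    cmd := exec.Command("ollama-runner") // hypothetical binary name
    cmd.Env = append(os.Environ(), visibleDevicesEnv...)
    fmt.Println(cmd.Env[len(cmd.Env)-1]) // ROCR_VISIBLE_DEVICES=0,1
}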

View File

@ -62,9 +62,9 @@ func GetCPUMem() (memInfo, error) {
    }, nil
}

-func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+func (l GpuInfoList) GetVisibleDevicesEnv() []string {
    // No-op on darwin
-    return "", ""
+    return nil
}

func GetSystemInfo() SystemInfo {

View File

@ -1,21 +0,0 @@
-//go:build linux || windows
-
-package discover
-
-import (
-    "log/slog"
-    "strings"
-)
-
-func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
-    ids := []string{}
-    for _, info := range gpuInfo {
-        if info.Library != "oneapi" {
-            // TODO shouldn't happen if things are wired correctly...
-            slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
-            continue
-        }
-        ids = append(ids, info.ID)
-    }
-    return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
-}

View File

@ -27,8 +27,8 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
    // Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
    DependencyPath []string `json:"lib_path,omitempty"`

-    // Extra environment variables specific to the GPU as list of [key,value]
-    EnvWorkarounds [][2]string `json:"envs,omitempty"`
+    // Extra environment variables specific to the GPU as list of [key=value]
+    EnvWorkarounds []string `json:"envs,omitempty"`

    // Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
    // the FreeMemory is best effort, and may over or under report actual memory usage
@ -37,6 +37,7 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
    // GPU information
    ID string `json:"gpu_id"` // string to use for selection of this specific GPU
+    filterID int //nolint:unused,nolintlint // AMD Workaround: The numeric ID of the device used to filter out other devices
    Name    string `json:"name"`    // user friendly name if available
    Compute string `json:"compute"` // Compute Capability or gfx

View File

@ -7,9 +7,11 @@ import (
    "fmt"
    "io"
    "log/slog"
+    "math"
    "slices"
    "strings"

+    "github.com/ollama/ollama/format"
    "github.com/ollama/ollama/fs/util/bufioutil"
)
@ -275,7 +277,7 @@ type Tensor struct {
func (t Tensor) block() (n int) {
    if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
-        return -1
+        return math.MaxInt
    }

    return
@ -288,24 +290,24 @@ func (t Tensor) blockSize() uint64 {
func (t TensorType) BlockSize() uint64 {
    switch t {
    case
-        0,  // F32
-        1,  // F16
-        24, // I8
-        25, // I16
-        26, // I32
-        27, // I64
-        28, // F64
-        30: // BF16
+        TensorTypeF32,
+        TensorTypeF16,
+        TensorTypeI8,
+        TensorTypeI16,
+        TensorTypeI32,
+        TensorTypeI64,
+        TensorTypeF64,
+        TensorTypeBF16:
        return 1
    case
-        2,  // Q4_0
-        3,  // Q4_1
-        4,  // MXFP4
-        6,  // Q5_0
-        7,  // Q5_1
-        8,  // Q8_0
-        9,  // Q8_1
-        20: // IQ4_NL
+        TensorTypeQ4_0,
+        TensorTypeQ4_1,
+        TensorTypeQ5_0,
+        TensorTypeQ5_1,
+        TensorTypeQ8_0,
+        TensorTypeQ8_1,
+        tensorTypeIQ4_NL,
+        4, TensorTypeMXFP4:
        return 32
    default:
        return 256
@ -328,8 +330,6 @@ func (t TensorType) TypeSize() uint64 {
        return 2 + blockSize/2
    case TensorTypeQ4_1:
        return 2 + 2 + blockSize/2
-    case TensorTypeMXFP4, 39:
-        return 1 + blockSize/2
    case TensorTypeQ5_0:
        return 2 + 4 + blockSize/2
    case TensorTypeQ5_1:
@ -380,6 +380,8 @@ func (t TensorType) TypeSize() uint64 {
        return blockSize/8 + blockSize/16 + blockSize/32
    case TensorTypeBF16:
        return 2
+    case 4, TensorTypeMXFP4:
+        return 1 + blockSize/2
    default:
        return 0
    }
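For illustration (not part of the patch): with a block size of 32 elements and a type size of 1 + 32/2 = 17 bytes, MXFP4 stores 17 bytes per 32 weights. A standalone arithmetic check, with a tensor size invented for the example:

package main

import "fmt"

func main() {
    const blockSize = 32             // elements per MXFP4 block
    const typeSize = 1 + blockSize/2 // 1 scale byte + 16 packed nibble bytes = 17

    elements := uint64(2880 * 2880) // made-up tensor size
    storageBytes := elements / blockSize * typeSize
    fmt.Println(storageBytes) // 4406400
}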
@ -479,7 +481,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
    }, nil
}

-func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
    context *= uint64(numParallel)

    embedding := f.KV().EmbeddingLength()
@ -677,7 +679,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
                kv[i] *= context
            }
        }
        partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
+        if useFlashAttention {
+            // rough estimate of graph size with flash attention on
+            partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
+        }
    }

    return
@ -773,6 +780,13 @@ func (f GGML) SupportsFlashAttention() bool {
    return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}

+// FlashAttention checks if the model should enable flash attention
+func (f GGML) FlashAttention() bool {
+    return slices.Contains([]string{
+        "gptoss", "gpt-oss",
+    }, f.KV().String("general.architecture"))
+}
+
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
func kvCacheBytesPerElement(cacheType string) float64 {
    switch cacheType {
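For illustration (not part of the patch): the flash-attention branch swaps the head-count based estimate for a flat formula in mebibytes. A standalone evaluation of that formula for one made-up configuration:

package main

import "fmt"

func main() {
    const mebiByte = 1024 * 1024 // mirrors format.MebiByte

    numParallel := uint64(4)
    context := uint64(8192) // context length after multiplying by numParallel

    // rough estimate of graph size with flash attention on
    partialOffload := (4*numParallel + context>>10 + 110) * mebiByte
    fmt.Println(partialOffload / mebiByte) // 134 (MiB)
}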

View File

@ -533,12 +533,15 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
        }
    }

-    slices.SortStableFunc(ts, func(a, b *Tensor) int {
-        if i, j := a.block(), b.block(); i > 0 && j > 0 {
-            return cmp.Compare(i, j)
-        }
-        return cmp.Compare(a.Name, b.Name)
-    })
+    slices.SortStableFunc(
+        ts,
+        func(a, b *Tensor) int {
+            return cmp.Or(
+                cmp.Compare(a.block(), b.block()),
+                cmp.Compare(a.Name, b.Name),
+            )
+        },
+    )

    var s uint64
    for i := range ts {
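For illustration (not part of the patch): cmp.Or of two cmp.Compare results gives a lexicographic sort key, so tensors order by block number first and by name only on ties; with block() now returning math.MaxInt for names without a blk.N prefix, those tensors sort after every block tensor. A standalone sketch with invented names:

package main

import (
    "cmp"
    "fmt"
    "math"
    "slices"
)

type tensor struct {
    block int
    name  string
}

func main() {
    ts := []tensor{
        {block: 1, name: "blk.1.ffn_up.weight"},
        {block: 0, name: "blk.0.ffn_norm.weight"},
        {block: math.MaxInt, name: "token_embd.weight"}, // no blk.N prefix
        {block: 0, name: "blk.0.attn_k.weight"},
    }

    slices.SortStableFunc(ts, func(a, b tensor) int {
        // Primary key: block number; secondary key: full name.
        return cmp.Or(cmp.Compare(a.block, b.block), cmp.Compare(a.name, b.name))
    })

    for _, t := range ts {
        fmt.Println(t.name)
    }
    // blk.0.attn_k.weight
    // blk.0.ffn_norm.weight
    // blk.1.ffn_up.weight
    // token_embd.weight
}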

View File

@ -11,24 +11,24 @@ import (
)

func TestWriteGGUF(t *testing.T) {
-    r := rand.New(rand.NewPCG(0, 0))
+    b := bytes.NewBuffer(make([]byte, 2*3))
    for range 8 {
        t.Run("shuffle", func(t *testing.T) {
            t.Parallel()

            ts := []*Tensor{
-                {Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
-                {Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
-                {Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
-                {Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
-                {Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
-                {Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
-                {Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
-                {Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
-                {Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
+                {Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
+                {Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
+                {Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
+                {Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
+                {Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
+                {Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
+                {Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
+                {Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
+                {Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
            }

-            r.Shuffle(len(ts), func(i, j int) {
+            rand.Shuffle(len(ts), func(i, j int) {
                ts[i], ts[j] = ts[j], ts[i]
            })
@ -63,14 +63,14 @@ func TestWriteGGUF(t *testing.T) {
            }

            if diff := cmp.Diff(Tensors{
-                Offset: 608,
+                Offset: 592,
                items: []*Tensor{
-                    {Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
-                    {Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
-                    {Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
-                    {Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
-                    {Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
-                    {Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
+                    {Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
+                    {Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
+                    {Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
+                    {Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
+                    {Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
+                    {Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
                    {Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
                    {Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
                    {Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},

View File

@ -146,8 +146,6 @@ func (ftype FileType) ToTensorType() TensorType {
        return TensorTypeQ4_0
    case fileTypeQ4_1:
        return TensorTypeQ4_1
-    case fileTypeMXFP4:
-        return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
    case FileTypeQ8_0:
        return TensorTypeQ8_0
    case fileTypeQ5_0:
@ -176,6 +174,8 @@ func (ftype FileType) ToTensorType() TensorType {
        return TensorTypeQ2_K
    case FileTypeBF16:
        return TensorTypeBF16
+    case fileTypeMXFP4:
+        return TensorTypeMXFP4
    default:
        slog.Warn("unsupported file type", "type", ftype)
        return 0 // F32
@ -191,7 +191,7 @@ const (
    TensorTypeF16
    TensorTypeQ4_0
    TensorTypeQ4_1
-    TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
+    tensorTypeQ4_2
    tensorTypeQ4_3 // unused by GGML
    TensorTypeQ5_0
    TensorTypeQ5_1
@ -226,6 +226,7 @@ const (
    tensorTypeIQ4_NL_4_4 // unused by GGML
    tensorTypeIQ4_NL_4_8 // unused by GGML
    tensorTypeIQ4_NL_8_8 // unused by GGML
+    TensorTypeMXFP4
)

// ParseFileType parses the provided GGUF file type
@ -318,7 +319,7 @@ func (t TensorType) String() string {
        return "F64"
    case TensorTypeBF16:
        return "BF16"
-    case TensorTypeMXFP4:
+    case 4, TensorTypeMXFP4:
        return "MXFP4"
    default:
        return "unknown"

View File

@ -1,10 +1,8 @@
-package server
+package harmony

import (
-    "context"
    "fmt"
    "log/slog"
-    "slices"
    "strings"
    "unicode"
@ -20,18 +18,6 @@ const (
    harmonyParserState_ParsingContent
)

-func shouldUseHarmony(model Model) bool {
-    if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
-        // heuristic to check whether the template expects to be parsed via harmony:
-        // search for harmony tags that are nearly always used
-        if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
-            return true
-        }
-    }
-
-    return false
-}
-
func (s harmonyParserState) String() string {
    switch s {
    // we're looking for the message start tag
@ -277,20 +263,20 @@ const (
// This is a higher level interface that maps harmony concepts into ollama concepts
type HarmonyMessageHandler struct {
    state harmonyMessageState

-    harmonyParser   *HarmonyParser
-    functionNameMap *FunctionNameMap
+    HarmonyParser   *HarmonyParser
+    FunctionNameMap *FunctionNameMap
}

// NewHarmonyMessageHandler creates a new message handler
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
    return &HarmonyMessageHandler{
        state: harmonyMessageState_Normal,
-        harmonyParser: &HarmonyParser{
+        HarmonyParser: &HarmonyParser{
            MessageStartTag: "<|start|>",
            MessageEndTag:   "<|end|>",
            HeaderEndTag:    "<|message|>",
        },
-        functionNameMap: NewFunctionNameMap(),
+        FunctionNameMap: NewFunctionNameMap(),
    }
}
@ -301,11 +287,11 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
    thinkingSb := strings.Builder{}
    toolContentSb := strings.Builder{}

-    events := h.harmonyParser.AddContent(content)
+    events := h.HarmonyParser.AddContent(content)
    for _, event := range events {
        switch event := event.(type) {
        case HarmonyEventHeaderComplete:
-            slog.Log(context.TODO(), logutil.LevelTrace, "harmony event header complete", "header", event.Header)
+            logutil.Trace("harmony event header complete", "header", event.Header)
            switch event.Header.Channel {
            case "analysis":
                if event.Header.Recipient != "" {
@ -328,7 +314,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
                h.state = harmonyMessageState_Normal
            }
        case HarmonyEventContentEmitted:
-            slog.Log(context.TODO(), logutil.LevelTrace, "harmony event content", "content", event.Content, "state", h.state)
+            logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
            if h.state == harmonyMessageState_Normal {
                contentSb.WriteString(event.Content)
            } else if h.state == harmonyMessageState_Thinking {

View File

@ -1,4 +1,4 @@
-package server
+package harmony

import (
    "fmt"

View File

@ -2,10 +2,13 @@
This directory contains integration tests to exercise Ollama end-to-end to verify behavior

-By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
+By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`

The integration tests have 2 modes of operating.

1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
-2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
+2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
+
+> [!IMPORTANT]
+> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.

View File

@ -390,7 +390,7 @@ func TestAPIEmbeddings(t *testing.T) {
    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()

    req := api.EmbeddingRequest{
-        Model:  "orca-mini",
+        Model:  libraryEmbedModels[0],
        Prompt: "why is the sky blue?",
        Options: map[string]interface{}{
            "temperature": 0,

View File

@ -11,7 +11,6 @@ import (
    "time"

    "github.com/ollama/ollama/api"
-    "github.com/stretchr/testify/require"
)

func TestBlueSky(t *testing.T) {
@ -37,8 +36,8 @@ func TestUnicode(t *testing.T) {
    // Set up the test data
    req := api.GenerateRequest{
        // DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
-        Model:  "deepseek-coder-v2:16b-lite-instruct-q2_K",
-        Prompt: "天空为什么是蓝色的?",
+        Model:  "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
+        Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
        Stream: &stream,
        Options: map[string]any{
            "temperature": 0,
@ -50,8 +49,20 @@ func TestUnicode(t *testing.T) {
    }

    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()
-    require.NoError(t, PullIfMissing(ctx, client, req.Model))
-    DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second)
+    if err := PullIfMissing(ctx, client, req.Model); err != nil {
+        t.Fatal(err)
+    }
+    slog.Info("loading", "model", req.Model)
+    err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
+    if err != nil {
+        t.Fatalf("failed to load model %s: %s", req.Model, err)
+    }
+    skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
+    DoGenerate(ctx, t, client, req, []string{
+        "散射", // scattering
+        "频率", // frequency
+    }, 120*time.Second, 120*time.Second)
}

func TestExtendedUnicodeOutput(t *testing.T) {
@ -69,7 +80,9 @@ func TestExtendedUnicodeOutput(t *testing.T) {
    }

    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()
-    require.NoError(t, PullIfMissing(ctx, client, req.Model))
+    if err := PullIfMissing(ctx, client, req.Model); err != nil {
+        t.Fatal(err)
+    }
    DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
}
@ -84,7 +97,9 @@ func TestUnicodeModelDir(t *testing.T) {
    }

    modelDir, err := os.MkdirTemp("", "ollama_埃")
-    require.NoError(t, err)
+    if err != nil {
+        t.Fatal(err)
+    }
    defer os.RemoveAll(modelDir)
    slog.Info("unicode", "OLLAMA_MODELS", modelDir)

View File

@ -14,8 +14,6 @@ import (
    "testing"
    "time"

-    "github.com/stretchr/testify/require"
-
    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/envconfig"
    "github.com/ollama/ollama/format"
@ -79,21 +77,21 @@ func TestMultiModelStress(t *testing.T) {
        t.Fatal(err)
    }

+    // All models compatible with ollama-engine
    smallModels := []string{
        "llama3.2:1b",
        "qwen3:0.6b",
-        "gemma:2b",
-        "deepseek-r1:1.5b",
-        "starcoder2:3b",
+        "gemma2:2b",
+        "deepseek-r1:1.5b", // qwen2 arch
+        "gemma3:270m",
    }
    mediumModels := []string{
-        "qwen3:8b",
-        "llama2",
-        "deepseek-r1:7b",
-        "mistral",
-        "dolphin-mistral",
-        "gemma:7b",
-        "codellama:7b",
+        "llama3.2:3b",    // ~3.4G
+        "qwen3:8b",       // ~6.6G
+        "gpt-oss:20b",    // ~15G
+        "deepseek-r1:7b", // ~5.6G
+        "gemma3:4b",      // ~5.8G
+        "gemma2:9b",      // ~8.1G
    }

    var chosenModels []string
@ -114,7 +112,9 @@ func TestMultiModelStress(t *testing.T) {
    // Make sure all the models are pulled before we get started
    for _, model := range chosenModels {
-        require.NoError(t, PullIfMissing(ctx, client, model))
+        if err := PullIfMissing(ctx, client, model); err != nil {
+            t.Fatal(err)
+        }
    }

    // Determine how many models we can load in parallel before we exceed VRAM

View File

@ -22,7 +22,7 @@ func TestLongInputContext(t *testing.T) {
    defer cancel()
    // Set up the test data
    req := api.GenerateRequest{
-        Model:  "llama2",
+        Model:  smol,
        Prompt: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
        Stream: &stream,
        Options: map[string]any{
@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
    if err := PullIfMissing(ctx, client, req.Model); err != nil {
        t.Fatalf("PullIfMissing failed: %v", err)
    }
-    DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second)
+    DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
}
func TestContextExhaustion(t *testing.T) { func TestContextExhaustion(t *testing.T) {
@ -49,7 +49,7 @@ func TestContextExhaustion(t *testing.T) {
    defer cancel()
    // Set up the test data
    req := api.GenerateRequest{
-        Model:  "llama2",
+        Model:  smol,
        Prompt: "Write me a story with a ton of emojis?",
        Stream: &stream,
        Options: map[string]any{
@ -63,10 +63,10 @@ func TestContextExhaustion(t *testing.T) {
    if err := PullIfMissing(ctx, client, req.Model); err != nil {
        t.Fatalf("PullIfMissing failed: %v", err)
    }
-    DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
+    DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
}
-// Send multiple requests with prior context and ensure the response is coherant and expected
+// Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestGenerateWithHistory(t *testing.T) {
    modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
    req, resp := GenerateRequests()
@ -111,5 +111,56 @@ func TestGenerateWithHistory(t *testing.T) {
        }(i)
    }
    wg.Wait()
+}
+
+// Send multiple chat requests with prior context and ensure the response is coherant and expected
+func TestChatWithHistory(t *testing.T) {
+    modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
+    req, resp := ChatRequests()
+    numParallel := 2
+    iterLimit := 2
+
+    softTimeout, hardTimeout := getTimeouts(t)
+    ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    // Get the server running (if applicable) warm the model up with a single initial empty request
+    slog.Info("loading", "model", modelOverride)
+    err := client.Generate(ctx,
+        &api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
+        func(response api.GenerateResponse) error { return nil },
+    )
+    if err != nil {
+        t.Fatalf("failed to load model %s: %s", modelOverride, err)
+    }
+
+    var wg sync.WaitGroup
+    wg.Add(numParallel)
+    for i := range numParallel {
+        go func(i int) {
+            defer wg.Done()
+            k := i % len(req)
+            req[k].Model = modelOverride
+            for j := 0; j < iterLimit; j++ {
+                if time.Now().Sub(started) > softTimeout {
+                    slog.Info("exceeded soft timeout, winding down test")
+                    return
+                }
+                slog.Info("Starting", "thread", i, "iter", j)
+                // On slower GPUs it can take a while to process the concurrent requests
+                // so we allow a much longer initial timeout
+                assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
+                if assistant == nil {
+                    t.Fatalf("didn't get an assistant response for context")
+                }
+                req[k].Messages = append(req[k].Messages,
+                    *assistant,
+                    api.Message{Role: "user", Content: "tell me more!"},
+                )
+            }
+        }(i)
+    }
+    wg.Wait()
}

View File

@ -9,7 +9,6 @@ import (
    "time"

    "github.com/ollama/ollama/api"
-    "github.com/stretchr/testify/require"
)

func TestVisionModels(t *testing.T) {
@ -32,7 +31,9 @@ func TestVisionModels(t *testing.T) {
    for _, v := range testCases {
        t.Run(v.model, func(t *testing.T) {
            image, err := base64.StdEncoding.DecodeString(imageEncoding)
-            require.NoError(t, err)
+            if err != nil {
+                t.Fatal(err)
+            }
            req := api.GenerateRequest{
                Model:  v.model,
                Prompt: "what does the text in this image say?",
@ -52,7 +53,9 @@ func TestVisionModels(t *testing.T) {
            // Note: sometimes it returns "the ollamas" sometimes "the ollams"
            resp := "the ollam"
            defer cleanup()
-            require.NoError(t, PullIfMissing(ctx, client, req.Model))
+            if err := PullIfMissing(ctx, client, req.Model); err != nil {
+                t.Fatal(err)
+            }
            // llava models on CPU can be quite slow to start
            DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
        })
@ -62,7 +65,9 @@ func TestVisionModels(t *testing.T) {
func TestIntegrationSplitBatch(t *testing.T) {
    skipUnderMinVRAM(t, 6)
    image, err := base64.StdEncoding.DecodeString(imageEncoding)
-    require.NoError(t, err)
+    if err != nil {
+        t.Fatal(err)
+    }
    req := api.GenerateRequest{
        Model: "gemma3:4b",
        // Fill up a chunk of the batch so the image will partially spill over into the next one
@ -84,7 +89,9 @@ func TestIntegrationSplitBatch(t *testing.T) {
    defer cancel()
    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()
-    require.NoError(t, PullIfMissing(ctx, client, req.Model))
+    if err := PullIfMissing(ctx, client, req.Model); err != nil {
+        t.Fatal(err)
+    }
    // llava models on CPU can be quite slow to start,
    DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
}

View File

@ -1,47 +0,0 @@
-//go:build integration
-
-package integration
-
-import (
-    "context"
-    "testing"
-    "time"
-
-    "github.com/ollama/ollama/api"
-)
-
-// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
-// package to avoid circular dependencies
-
-var (
-    stream = false
-    req    = [2]api.GenerateRequest{
-        {
-            Model:  smol,
-            Prompt: "why is the ocean blue?",
-            Stream: &stream,
-            Options: map[string]any{
-                "seed":        42,
-                "temperature": 0.0,
-            },
-        }, {
-            Model:  smol,
-            Prompt: "what is the origin of the us thanksgiving holiday?",
-            Stream: &stream,
-            Options: map[string]any{
-                "seed":        42,
-                "temperature": 0.0,
-            },
-        },
-    }
-    resp = [2][]string{
-        {"sunlight", "scattering", "interact"},
-        {"england", "english", "massachusetts", "pilgrims"},
-    }
-)
-
-func TestIntegrationSimple(t *testing.T) {
-    ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
-    defer cancel()
-    GenerateTestHelper(ctx, t, req[0], resp[0])
-}

View File

@ -13,12 +13,12 @@ import (
    "testing"
    "time"

-    "github.com/stretchr/testify/require"
    "github.com/ollama/ollama/api"
)

func TestMaxQueue(t *testing.T) {
+    t.Skip("this test needs to be re-evaluated to use a proper embedding model")
    if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
        t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
        return
@ -45,7 +45,9 @@ func TestMaxQueue(t *testing.T) {
    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()

-    require.NoError(t, PullIfMissing(ctx, client, req.Model))
+    if err := PullIfMissing(ctx, client, req.Model); err != nil {
+        t.Fatal(err)
+    }

    // Context for the worker threads so we can shut them down
    // embedCtx, embedCancel := context.WithCancel(ctx)
@ -89,7 +91,9 @@ func TestMaxQueue(t *testing.T) {
            switch {
            case genErr == nil:
                successCount++
-                require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
+                if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
+                    t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
+                }
            case errors.Is(genErr, context.Canceled):
                canceledCount++
            case strings.Contains(genErr.Error(), "busy"):
@ -97,7 +101,9 @@ func TestMaxQueue(t *testing.T) {
            case strings.Contains(genErr.Error(), "connection reset by peer"):
                resetByPeerCount++
            default:
-                require.NoError(t, genErr, "%d request failed", i)
+                if genErr != nil {
+                    t.Fatalf("%d request failed", i)
+                }
            }

            slog.Info("embed finished", "id", i)
@ -108,8 +114,13 @@ func TestMaxQueue(t *testing.T) {
    embedwg.Wait()

    slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
-    require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
-    require.True(t, busyCount > 0, "no requests hit busy error but some should have")
-    require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
+    if resetByPeerCount != 0 {
+        t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
+    }
+    if busyCount == 0 {
+        t.Fatalf("no requests hit busy error but some should have")
+    }
+    if canceledCount > 0 {
+        t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
+    }
}

View File

@ -9,6 +9,7 @@ import (
    "fmt"
    "io"
    "log/slog"
+    "math"
    "math/rand"
    "net"
    "net/http"
@ -25,11 +26,11 @@ import (
    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/app/lifecycle"
    "github.com/ollama/ollama/format"
-    "github.com/stretchr/testify/require"
)

var (
    smol = "llama3.2:1b"
+    stream = false
)

var (
@ -435,7 +436,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
        }
        lifecycle.ServerLogFile = fp.Name()
        fp.Close()
-        require.NoError(t, startServer(t, ctx, testEndpoint))
+        if err := startServer(t, ctx, testEndpoint); err != nil {
+            t.Fatal(err)
+        }
    }

    return client, testEndpoint, func() {
@ -468,7 +471,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) { func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, genReq.Model)) if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
t.Fatal(err)
}
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second) DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
} }
@ -509,7 +514,9 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr) slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
return context return context
} }
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt) if genErr != nil {
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
}
// Verify the response contains the expected data // Verify the response contains the expected data
response := buf.String() response := buf.String()
atLeastOne := false atLeastOne := false
@ -519,7 +526,9 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
break break
} }
} }
require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response) if !atLeastOne {
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
}
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response) slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done(): case <-ctx.Done():
t.Error("outer test context done while waiting for generate") t.Error("outer test context done while waiting for generate")
@ -561,17 +570,97 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[][]string{ [][]string{
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"}, {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"}, {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states"}, {"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
{"fourth", "july", "declaration", "independence"}, {"fourth", "july", "declaration", "independence"},
{"nitrogen", "oxygen", "carbon", "dioxide"}, {"nitrogen", "oxygen", "carbon", "dioxide"},
} }
} }
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
role := "assistant"
fn := func(response api.ChatResponse) error {
// fmt.Print(".")
role = response.Message.Role
buf.Write([]byte(response.Message.Content))
if !stallTimer.Reset(streamTimeout) {
return errors.New("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
return nil
}
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
}
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
return &api.Message{Role: role, Content: buf.String()}
}
func ChatRequests() ([]api.ChatRequest, [][]string) {
genReqs, results := GenerateRequests()
reqs := make([]api.ChatRequest, len(genReqs))
// think := api.ThinkValue{Value: "low"}
for i := range reqs {
reqs[i].Model = genReqs[i].Model
reqs[i].Stream = genReqs[i].Stream
reqs[i].KeepAlive = genReqs[i].KeepAlive
// reqs[i].Think = &think
reqs[i].Messages = []api.Message{
{
Role: "user",
Content: genReqs[i].Prompt,
},
}
}
return reqs, results
}
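DoChat mirrors DoGenerate for the chat endpoint and ChatRequests reuses the generate prompts as single-turn conversations. A hedged sketch of how a test in this package might drive the two helpers together; the test name, timeouts, and soft-timeout early exit are assumptions, not part of this change:

//go:build integration

package integration

import (
	"context"
	"testing"
	"time"
)

func TestChatSmoke(t *testing.T) {
	softTimeout, hardTimeout := getTimeouts(t)
	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	reqs, expected := ChatRequests()
	started := time.Now()
	for i := range reqs {
		if time.Since(started) > softTimeout {
			t.Log("soft timeout reached, stopping early")
			break
		}
		if err := PullIfMissing(ctx, client, reqs[i].Model); err != nil {
			t.Fatal(err)
		}
		// DoChat streams the reply and fails the test if the stream stalls
		// or none of the expected keywords appear in the response.
		DoChat(ctx, t, client, reqs[i], expected[i], 120*time.Second, 30*time.Second)
	}
}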
func skipUnderMinVRAM(t *testing.T, gb uint64) { func skipUnderMinVRAM(t *testing.T, gb uint64) {
// TODO use info API in the future // TODO use info API in the future
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" { if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err := strconv.ParseUint(s, 10, 64) maxVram, err := strconv.ParseUint(s, 10, 64)
require.NoError(t, err) if err != nil {
t.Fatal(err)
}
// Don't hammer on small VRAM cards... // Don't hammer on small VRAM cards...
if maxVram < gb*format.GibiByte { if maxVram < gb*format.GibiByte {
t.Skip("skipping with small VRAM to avoid timeouts") t.Skip("skipping with small VRAM to avoid timeouts")
@ -579,6 +668,39 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
} }
} }
// Skip if the target model isn't at least minPercent% loaded on GPU, to avoid excessive runtime
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
models, err := client.ListRunning(ctx)
if err != nil {
t.Fatalf("failed to list running models: %s", err)
}
loaded := []string{}
for _, m := range models.Models {
loaded = append(loaded, m.Name)
if m.Name != model {
continue
}
gpuPercent := 0
switch {
case m.SizeVRAM == 0:
gpuPercent = 0
case m.SizeVRAM == m.Size:
gpuPercent = 100
case m.SizeVRAM > m.Size || m.Size == 0:
t.Logf("unexpected size detected: %d", m.SizeVRAM)
default:
sizeCPU := m.Size - m.SizeVRAM
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
gpuPercent = int(100 - cpuPercent)
}
if gpuPercent < minPercent {
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
}
return
}
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
}
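A hedged sketch of how the two skip helpers compose in a test, assuming it sits in this package behind the same build tag; the warm-up generate call and the 80% threshold are illustrative:

//go:build integration

package integration

import (
	"context"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func TestNeedsMostlyGPU(t *testing.T) {
	skipUnderMinVRAM(t, 8) // skip outright on small cards

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	if err := PullIfMissing(ctx, client, smol); err != nil {
		t.Fatal(err)
	}
	// Load the model with an empty generate request so it shows up in ListRunning.
	if err := client.Generate(ctx, &api.GenerateRequest{Model: smol}, func(api.GenerateResponse) error { return nil }); err != nil {
		t.Fatal(err)
	}
	// Skip unless at least 80% of the model is resident in VRAM.
	skipIfNotGPULoaded(ctx, t, client, smol, 80)

	// GPU-heavy assertions would follow here.
}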
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) { func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
deadline, hasDeadline := t.Deadline() deadline, hasDeadline := t.Deadline()
if !hasDeadline { if !hasDeadline {

View File

@ -0,0 +1,130 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jesse Gross <jesse@ollama.com>
Date: Wed, 27 Aug 2025 14:39:48 -0700
Subject: [PATCH] ggml: Enable resetting backend devices
Touching a CUDA device causes the allocation of a primary context
with CUDA data structures (~300 MB of VRAM). If a device is
unused then it can be reset to free these data structures.
---
ggml/include/ggml-backend.h | 1 +
ggml/src/ggml-backend-impl.h | 4 ++++
ggml/src/ggml-backend.cpp | 8 ++++++++
ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++++--
ggml/src/ggml-cuda/vendors/hip.h | 1 +
5 files changed, 29 insertions(+), 2 deletions(-)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index b602a7c78..fda5ceb24 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -167,6 +167,7 @@ extern "C" {
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
+ GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 81749a5a3..6f10c353b 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -178,6 +178,10 @@ extern "C" {
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
+
+ // (optional) reset device, clearing existing allocations and context
+ // the caller must ensure that there are no outstanding buffers, as these will become invalid
+ void (*reset)(ggml_backend_dev_t dev);
};
struct ggml_backend_device {
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 05a842ed5..6556943b0 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
return device->iface.init_backend(device, params);
}
+void ggml_backend_dev_reset(ggml_backend_dev_t device) {
+ if (device->iface.reset == NULL) {
+ return;
+ }
+
+ device->iface.reset(device);
+}
+
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
return device->iface.get_buffer_type(device);
}
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c7f9dc3a5..e43fde523 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
return id;
}
+void ggml_cuda_reset_device(int device) {
+ ggml_cuda_set_device(device);
+ CUDA_CHECK(cudaDeviceReset());
+}
+
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device);
cudaError_t err;
@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->description = ggml_backend_cuda_device_get_description(dev);
props->id = ggml_backend_cuda_device_get_id(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+ // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
+ // If you need the memory data, call ggml_backend_dev_memory() explicitly.
+ props->memory_total = props->memory_free = 0;
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY
@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
}
+static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ ggml_cuda_reset_device(ctx->device);
+}
+
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .get_name = */ ggml_backend_cuda_device_get_name,
/* .get_description = */ ggml_backend_cuda_device_get_description,
@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .event_new = */ ggml_backend_cuda_device_event_new,
/* .event_free = */ ggml_backend_cuda_device_event_free,
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
+ /* .reset = */ ggml_backend_cuda_device_reset,
};
// backend reg
@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
dev_ctx->device = i;
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
- ggml_cuda_set_device(i);
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index c31f31923..cf22e60d2 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -40,6 +40,7 @@
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceReset hipDeviceReset
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled

View File

@ -0,0 +1,28 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Fri, 29 Aug 2025 16:53:08 -0700
Subject: [PATCH] harden uncaught exception registration
---
ggml/src/ggml.cpp | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp
index 0d388d45..f5bcb446 100644
--- a/ggml/src/ggml.cpp
+++ b/ggml/src/ggml.cpp
@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
return false;
}
const auto prev{std::get_terminate()};
- GGML_ASSERT(prev != ggml_uncaught_exception);
- previous_terminate_handler = prev;
+ // GGML_ASSERT(prev != ggml_uncaught_exception);
+ if (prev != ggml_uncaught_exception) {
+ previous_terminate_handler = prev;
+ } else {
+ GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
+ }
std::set_terminate(ggml_uncaught_exception);
return true;
}();

View File

@ -30,7 +30,7 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin
// Try to pack into as few GPUs as possible, starting from 1 GPU // Try to pack into as few GPUs as possible, starting from 1 GPU
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ { for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
gpuSubset := sgl[:numGPUs] gpuSubset := sgl[:numGPUs]
ok, estimatedVRAM := PredictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel) ok, estimatedVRAM := predictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
if ok { if ok {
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading", slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
@ -48,7 +48,7 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin
// - try subsets of GPUs instead of just falling back to 1 or all in a family // - try subsets of GPUs instead of just falling back to 1 or all in a family
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set) // Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
if ok, estimatedVRAM := PredictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok { if ok, estimatedVRAM := predictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
slog.Info("new model will fit in available VRAM, loading", slog.Info("new model will fit in available VRAM, loading",
"model", modelPath, "model", modelPath,
"library", sgl[0].Library, "library", sgl[0].Library,
@ -71,7 +71,7 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s
var bestEstimate uint64 var bestEstimate uint64
var bestFit int var bestFit int
for i, gl := range byLibrary { for i, gl := range byLibrary {
_, estimatedVRAM := PredictServerFit(gl, f, adapters, projectors, opts, numParallel) _, estimatedVRAM := predictServerFit(gl, f, adapters, projectors, opts, numParallel)
if estimatedVRAM > bestEstimate { if estimatedVRAM > bestEstimate {
bestEstimate = estimatedVRAM bestEstimate = estimatedVRAM
bestFit = i bestFit = i
@ -81,7 +81,7 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s
} }
// This algorithm looks for a complete fit to determine if we need to unload other models // This algorithm looks for a complete fit to determine if we need to unload other models
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) { func predictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
// Split up the GPUs by type and try them // Split up the GPUs by type and try them
var estimatedVRAM uint64 var estimatedVRAM uint64
for _, gpus := range allGpus.ByLibrary() { for _, gpus := range allGpus.ByLibrary() {
@ -97,6 +97,10 @@ func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, proj
return true, estimatedVRAM return true, estimatedVRAM
} }
} }
if len(gpus) == 1 && gpus[0].Library == "cpu" && estimate.TotalSize <= gpus[0].FreeMemory {
return true, estimatedVRAM
}
} }
return false, estimatedVRAM return false, estimatedVRAM
} }
@ -191,17 +195,19 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
slog.Warn("model missing blk.0 layer size") slog.Warn("model missing blk.0 layer size")
} }
var kvct string useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
if envconfig.FlashAttention() &&
discover.GetGPUInfo().FlashAttentionSupported() && discover.GetGPUInfo().FlashAttentionSupported() &&
f.SupportsFlashAttention() { f.SupportsFlashAttention()
var kvct string
if useFlashAttention {
requested := strings.ToLower(envconfig.KvCacheType()) requested := strings.ToLower(envconfig.KvCacheType())
if requested != "" && f.SupportsKVCacheType(requested) { if requested != "" && f.SupportsKVCacheType(requested) {
kvct = requested kvct = requested
} }
} }
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct) kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
if len(kv) > 0 { if len(kv) > 0 {
layerSize += kv[0] layerSize += kv[0]

View File

@ -195,6 +195,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset // This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
// that can handle it. // that can handle it.
fa := envconfig.FlashAttention() fa := envconfig.FlashAttention()
if f.FlashAttention() {
slog.Info("model wants flash attention")
fa = true
}
if fa && !gpus.FlashAttentionSupported() { if fa && !gpus.FlashAttentionSupported() {
slog.Warn("flash attention enabled but not supported by gpu") slog.Warn("flash attention enabled but not supported by gpu")
fa = false fa = false
@ -355,23 +360,28 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator))) s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
envWorkarounds := [][2]string{} envWorkarounds := []string{}
for _, gpu := range gpus { for _, gpu := range gpus {
envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...) envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
} }
// Always filter down the set of GPUs in case there are any unsupported devices that might crash
envWorkarounds = append(envWorkarounds, gpus.GetVisibleDevicesEnv()...)
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add the path variable with our adjusted version // Update or add the path variable with our adjusted version
pathNeeded := true pathNeeded := true
envWorkaroundDone := make([]bool, len(envWorkarounds))
for i := range s.cmd.Env { for i := range s.cmd.Env {
cmp := strings.SplitN(s.cmd.Env[i], "=", 2) cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
if strings.EqualFold(cmp[0], pathEnv) { if strings.EqualFold(cmp[0], pathEnv) {
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
pathNeeded = false pathNeeded = false
} else if len(envWorkarounds) != 0 { } else if len(envWorkarounds) != 0 {
for _, kv := range envWorkarounds { for j, kv := range envWorkarounds {
if strings.EqualFold(cmp[0], kv[0]) { tmp := strings.SplitN(kv, "=", 2)
s.cmd.Env[i] = kv[0] + "=" + kv[1] if strings.EqualFold(cmp[0], tmp[0]) {
s.cmd.Env[i] = kv
envWorkaroundDone[j] = true
} }
} }
} }
@ -379,6 +389,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
if pathNeeded { if pathNeeded {
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal) s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
} }
for i, done := range envWorkaroundDone {
if !done {
s.cmd.Env = append(s.cmd.Env, envWorkarounds[i])
}
}
slog.Info("starting runner", "cmd", s.cmd) slog.Info("starting runner", "cmd", s.cmd)
slog.Debug("subprocess", "", filteredEnv(s.cmd.Env)) slog.Debug("subprocess", "", filteredEnv(s.cmd.Env))
@ -492,6 +507,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
if !requireFull { if !requireFull {
g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel) g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
} else { } else {
slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
return ErrLoadRequiredFull return ErrLoadRequiredFull
} }
} }
@ -524,10 +540,6 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
} }
} }
if requireFull && len(gpus) == 1 && gpus[0].Library == "cpu" && s.estimate.TotalSize > gpus[0].FreeMemory {
return ErrLoadRequiredFull
}
slog.Info("offload", "", s.estimate) slog.Info("offload", "", s.estimate)
s.gpus = gpus s.gpus = gpus
@ -666,8 +678,12 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ
if !(len(gpus) == 1 && gpus[0].Library == "cpu") { if !(len(gpus) == 1 && gpus[0].Library == "cpu") {
for _, gpu := range gpus { for _, gpu := range gpus {
available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory
if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory {
available = 0
}
slog.Info("gpu memory", "id", gpu.ID, slog.Info("gpu memory", "id", gpu.ID,
"available", format.HumanBytes2(gpu.FreeMemory-envconfig.GpuOverhead()-gpu.MinimumMemory), "available", format.HumanBytes2(available),
"free", format.HumanBytes2(gpu.FreeMemory), "free", format.HumanBytes2(gpu.FreeMemory),
"minimum", format.HumanBytes2(gpu.MinimumMemory), "minimum", format.HumanBytes2(gpu.MinimumMemory),
"overhead", format.HumanBytes2(envconfig.GpuOverhead())) "overhead", format.HumanBytes2(envconfig.GpuOverhead()))
@ -849,7 +865,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
} }
layers[i] += memory.CPU.Weights[i].Size layers[i] += memory.CPU.Weights[i].Size
layers[i] += memory.CPU.Cache[i].Size layers[i] += memory.CPU.Cache[i].Size
slog.Log(context.TODO(), logutil.LevelTrace, "layer to assign", "layer", i, "size", format.HumanBytes2(layers[i])) logutil.Trace("layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
} }
gpuLayers := ml.GPULayersList{} gpuLayers := ml.GPULayersList{}

View File

@ -1,6 +1,7 @@
package logutil package logutil
import ( import (
"context"
"io" "io"
"log/slog" "log/slog"
"path/filepath" "path/filepath"
@ -27,3 +28,11 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
}, },
})) }))
} }
func Trace(msg string, args ...any) {
slog.Log(context.TODO(), LevelTrace, msg, args...)
}
func TraceContext(ctx context.Context, msg string, args ...any) {
slog.Log(ctx, LevelTrace, msg, args...)
}

View File

@ -266,7 +266,7 @@ func (m DeviceMemory) LogValue() slog.Value {
// allocation is guaranteed to be provided so that if it failed, the caller can // allocation is guaranteed to be provided so that if it failed, the caller can
// accommodate that to make forward progress. // accommodate that to make forward progress.
type BackendMemory struct { type BackendMemory struct {
// InputsWeights are always located on the CPU and cannot be moved // InputWeights are always located on the CPU and cannot be moved
InputWeights Memory InputWeights Memory
// CPU model components are located in system memory. This does not // CPU model components are located in system memory. This does not
@ -372,6 +372,7 @@ type Context interface {
Forward(...Tensor) Context Forward(...Tensor) Context
Compute(...Tensor) Compute(...Tensor)
ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
// Reserve is analogous to Compute but rather than executing a // Reserve is analogous to Compute but rather than executing a
// graph, simply preallocates memory. Typically called with a // graph, simply preallocates memory. Typically called with a
@ -401,6 +402,8 @@ type Tensor interface {
Bytes() []byte Bytes() []byte
Floats() []float32 Floats() []float32
SetValueFromIntSlice(s []int32)
Neg(ctx Context) Tensor Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor Sub(ctx Context, t2 Tensor) Tensor

View File

@ -82,6 +82,7 @@ type Backend struct {
// to the name that is used by the model definition // to the name that is used by the model definition
tensorLoadTargets map[string][]string tensorLoadTargets map[string][]string
schedMu sync.Mutex // Only one Compute can run at a time
sched C.ggml_backend_sched_t sched C.ggml_backend_sched_t
schedBackends []C.ggml_backend_t schedBackends []C.ggml_backend_t
schedBufts []C.ggml_backend_buffer_type_t schedBufts []C.ggml_backend_buffer_type_t
@ -270,7 +271,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0]))) tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
C.ggml_set_name(tt, cname) C.ggml_set_name(tt, cname)
slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt))) logutil.Trace("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt)) size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
if layer == -1 { if layer == -1 {
@ -377,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
} }
for bs := range maps.Values(bbs) { for bs := range maps.Values(bbs) {
slog.Log(context.TODO(), logutil.LevelTrace, "model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs)))) "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
} }
@ -535,6 +536,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
const BS = 17 // MXFP4 block size const BS = 17 // MXFP4 block size
bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
var s uint64 var s uint64
var tmp [16]byte
for s < t.Size() { for s < t.Size() {
// Stop if either the parent context has been canceled or if any of the other tensors returned an error // Stop if either the parent context has been canceled or if any of the other tensors returned an error
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
@ -546,37 +548,13 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
return err return err
} }
for j := range n / BS { for j := range n / BS {
for i := 1; i < BS; i++ {
// swap nibbles
t_lo := bts[j*BS+i] & 0x0F
t_hi := bts[j*BS+i] & 0xF0
bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4)
}
// transform aaaa...bbbb... to abababab...
oi := 0
tmp := [16]byte{}
for i := 1; i < 9; i++ { for i := 1; i < 9; i++ {
blk_a0 := bts[j*BS+i] & 0xF0 // transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
blk_a1 := bts[j*BS+i] << 4 a, b := bts[j*BS+i], bts[j*BS+i+8]
blk_b0 := bts[j*BS+i+8] >> 4 tmp[2*(i-1)] = (a & 0x0F) | (b << 4)
blk_b1 := bts[j*BS+i+8] & 0x0F tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0)
// swap once more
out0 := blk_a0 | blk_b0
out1 := blk_a1 | blk_b1
out_h0 := out0 & 0xF0
out_l0 := out0 & 0x0F
out_h1 := out1 & 0xF0
out_l1 := out1 & 0x0F
out0 = (out_h0 >> 4) | (out_l0 << 4)
out1 = (out_h1 >> 4) | (out_l1 << 4)
tmp[oi] = out0
oi++
tmp[oi] = out1
oi++
}
for i := range tmp {
bts[j*BS+i+1] = tmp[i]
} }
copy(bts[j*BS+1:j*BS+17], tmp[:])
} }
for _, tt := range tts { for _, tt := range tts {
@ -652,6 +630,18 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
}) })
} }
// Clean up any backend state from devices that we didn't end up using
nextDevice:
for _, d := range append(gpus, append(accels, cpus...)...) {
for _, backend := range b.schedBackends {
if d == C.ggml_backend_get_device(backend) {
continue nextDevice
}
}
C.ggml_backend_dev_reset(d)
}
if err := g.Wait(); err != nil { if err := g.Wait(); err != nil {
return err return err
} }
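The cleanup loop above is the Go-side consumer of the new ggml_backend_dev_reset: any device that was probed during load but never ended up backing the scheduler has its state released, which on CUDA frees the roughly 300 MB primary context described in the patch. A cgo-free sketch of the same labeled-continue filtering; the device type and names are illustrative:

package main

import "fmt"

type device string

// resetUnused resets every device in all that does not appear in used,
// mirroring the nextDevice loop over schedBackends.
func resetUnused(all, used []device) {
nextDevice:
	for _, d := range all {
		for _, u := range used {
			if d == u {
				continue nextDevice // device is in use, leave it alone
			}
		}
		fmt.Println("resetting", d) // stands in for ggml_backend_dev_reset
	}
}

func main() {
	resetUnused([]device{"cuda0", "cuda1", "cpu"}, []device{"cuda0", "cpu"})
	// resetting cuda1
}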
@ -769,6 +759,15 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
} }
func (c *Context) Compute(tensors ...ml.Tensor) { func (c *Context) Compute(tensors ...ml.Tensor) {
c.ComputeWithNotify(nil, tensors...)
}
func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
c.b.schedMu.Lock()
defer c.b.schedMu.Unlock()
if cb != nil {
go cb()
}
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS { if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
panic(fmt.Errorf("error computing ggml graph: %v", status)) panic(fmt.Errorf("error computing ggml graph: %v", status))
} }
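ComputeWithNotify serializes graph execution behind schedMu and only fires the callback, on its own goroutine, once the lock is held, that is, once this compute is the one about to run. A generic sketch of the pattern without the cgo call; the worker type is illustrative, not the real Backend:

package main

import (
	"fmt"
	"sync"
	"time"
)

type worker struct {
	mu sync.Mutex // only one compute can run at a time
}

// computeWithNotify runs fn exclusively; cb, if provided, is invoked
// concurrently once the lock is acquired, so callers can observe that
// compute has begun.
func (w *worker) computeWithNotify(cb func(), fn func()) {
	w.mu.Lock()
	defer w.mu.Unlock()
	if cb != nil {
		go cb()
	}
	fn()
}

func main() {
	var w worker
	var wg sync.WaitGroup
	wg.Add(1)
	w.computeWithNotify(
		func() { fmt.Println("compute started"); wg.Done() },
		func() { time.Sleep(10 * time.Millisecond); fmt.Println("compute finished") },
	)
	wg.Wait()
}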
@ -812,7 +811,7 @@ func (c *Context) Reserve() {
} }
} }
slog.Log(context.TODO(), logutil.LevelTrace, "compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
"buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size))) "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size)))
} }
@ -1021,6 +1020,12 @@ func (t *Tensor) Floats() (data []float32) {
return return
} }
func (t *Tensor) SetValueFromIntSlice(s []int32) {
if len(s) > 0 {
C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
}
}
func (t *Tensor) DType() ml.DType { func (t *Tensor) DType() ml.DType {
switch t.t._type { switch t.t._type {
case C.GGML_TYPE_F32: case C.GGML_TYPE_F32:

View File

@ -167,6 +167,7 @@ extern "C" {
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props); GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device); GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params); GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device); GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device); GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size); GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);

View File

@ -178,6 +178,10 @@ extern "C" {
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev); ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event); void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event); void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
// (optional) reset device, clearing existing allocations and context
// the caller must ensure that there are no outstanding buffers, as these will become invalid
void (*reset)(ggml_backend_dev_t dev);
}; };
struct ggml_backend_device { struct ggml_backend_device {

View File

@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
return device->iface.init_backend(device, params); return device->iface.init_backend(device, params);
} }
void ggml_backend_dev_reset(ggml_backend_dev_t device) {
if (device->iface.reset == NULL) {
return;
}
device->iface.reset(device);
}
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) { ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
return device->iface.get_buffer_type(device); return device->iface.get_buffer_type(device);
} }

View File

@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
return id; return id;
} }
void ggml_cuda_reset_device(int device) {
ggml_cuda_set_device(device);
CUDA_CHECK(cudaDeviceReset());
}
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device); ggml_cuda_set_device(device);
cudaError_t err; cudaError_t err;
@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->description = ggml_backend_cuda_device_get_description(dev); props->description = ggml_backend_cuda_device_get_description(dev);
props->id = ggml_backend_cuda_device_get_id(dev); props->id = ggml_backend_cuda_device_get_id(dev);
props->type = ggml_backend_cuda_device_get_type(dev); props->type = ggml_backend_cuda_device_get_type(dev);
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
// Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
// If you need the memory data, call ggml_backend_dev_memory() explicitly.
props->memory_total = props->memory_free = 0;
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY #ifdef GGML_CUDA_NO_PEER_COPY
@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
} }
static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_reset_device(ctx->device);
}
static const ggml_backend_device_i ggml_backend_cuda_device_interface = { static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .get_name = */ ggml_backend_cuda_device_get_name, /* .get_name = */ ggml_backend_cuda_device_get_name,
/* .get_description = */ ggml_backend_cuda_device_get_description, /* .get_description = */ ggml_backend_cuda_device_get_description,
@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .event_new = */ ggml_backend_cuda_device_event_new, /* .event_new = */ ggml_backend_cuda_device_event_new,
/* .event_free = */ ggml_backend_cuda_device_event_free, /* .event_free = */ ggml_backend_cuda_device_event_free,
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
/* .reset = */ ggml_backend_cuda_device_reset,
}; };
// backend reg // backend reg
@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
dev_ctx->device = i; dev_ctx->device = i;
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i); dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
ggml_cuda_set_device(i);
cudaDeviceProp prop; cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name; dev_ctx->description = prop.name;

View File

@ -40,6 +40,7 @@
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t #define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceReset hipDeviceReset
#define cudaDeviceSynchronize hipDeviceSynchronize #define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t #define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled

View File

@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
return false; return false;
} }
const auto prev{std::get_terminate()}; const auto prev{std::get_terminate()};
GGML_ASSERT(prev != ggml_uncaught_exception); // GGML_ASSERT(prev != ggml_uncaught_exception);
if (prev != ggml_uncaught_exception) {
previous_terminate_handler = prev; previous_terminate_handler = prev;
} else {
GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
}
std::set_terminate(ggml_uncaught_exception); std::set_terminate(ggml_uncaught_exception);
return true; return true;
}(); }();

View File

@ -2,7 +2,6 @@ package model
import ( import (
"cmp" "cmp"
"context"
"fmt" "fmt"
"iter" "iter"
"log/slog" "log/slog"
@ -202,12 +201,11 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
} }
} }
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
if addSpecial && len(ids) > 0 { if addSpecial && len(ids) > 0 {
ids = bpe.vocab.addSpecials(ids) ids = bpe.vocab.addSpecials(ids)
} }
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil return ids, nil
} }
@ -243,6 +241,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
} }
} }
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids}) logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
return sb.String(), nil return sb.String(), nil
} }

View File

@ -1,12 +1,11 @@
package model package model
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
_ "image/jpeg" _ "image/jpeg"
_ "image/png" _ "image/png"
"log/slog" "math"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
@ -64,7 +63,7 @@ type MultimodalProcessor interface {
// This function is also responsible for updating MultimodalHash for any Multimodal // This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately // that is modified to ensure that there is a unique hash value that accurately
// represents the contents. // represents the contents.
PostTokenize([]input.Input) ([]input.Input, error) PostTokenize([]*input.Input) ([]*input.Input, error)
} }
// Base implements the common fields and methods for all models // Base implements the common fields and methods for all models
@ -105,6 +104,10 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
} }
arch := b.Config().Architecture() arch := b.Config().Architecture()
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
arch = arch + "_embed"
}
f, ok := models[arch] f, ok := models[arch]
if !ok { if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch) return nil, fmt.Errorf("unsupported model architecture %q", arch)
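Models whose GGUF metadata carries a pooling_type key are treated as embedding variants and dispatched to a separately registered "<arch>_embed" constructor (gemma3_embed below), with math.MaxUint32 doubling as the key-absent default. A small sketch of that dispatch rule; pick is a hypothetical helper:

package main

import (
	"fmt"
	"math"
)

// pick mirrors the constructor lookup: configs that declare a pooling_type
// are routed to the "<arch>_embed" variant of the architecture.
func pick(arch string, poolingType uint32) string {
	if poolingType != math.MaxUint32 { // MaxUint32 stands in for "key absent"
		return arch + "_embed"
	}
	return arch
}

func main() {
	fmt.Println(pick("gemma3", math.MaxUint32)) // gemma3
	fmt.Println(pick("gemma3", 1))              // gemma3_embed
}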
@ -198,7 +201,7 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
names := fn(tagsCopy) names := fn(tagsCopy)
for _, name := range names { for _, name := range names {
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "found tensor", "", tensor) logutil.Trace("found tensor", "", tensor)
vv.Set(reflect.ValueOf(tensor)) vv.Set(reflect.ValueOf(tensor))
break break
} }
@ -278,7 +281,7 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice t.Kind() == reflect.Slice
} }
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) { func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) { if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences)) return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
} }
@ -287,8 +290,6 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
return nil, errors.New("batch size cannot be less than 1") return nil, errors.New("batch size cannot be less than 1")
} }
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
cache := m.Config().Cache cache := m.Config().Cache
if cache != nil { if cache != nil {
err := cache.StartForward(ctx, batch, false) err := cache.StartForward(ctx, batch, false)
@ -302,7 +303,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
return nil, err return nil, err
} }
ctx.Forward(t).Compute(t) ctx.Forward(t)
return t, nil return t, nil
} }

View File

@ -0,0 +1,73 @@
package gemma3
import (
"errors"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type embedModel struct {
model.Base
model.SentencePieceModel
*TextModel
PoolingType uint32
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
batch.Outputs = batch.Positions // return all positions
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
switch m.PoolingType {
case 0: // None
case 1: // Mean
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
default:
return nil, errors.New("unsupported pooling type")
}
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePieceModel: model.NewSentencePieceModel(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
TextModel: newTextModel(c),
PoolingType: c.Uint("pooling_type", 0),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return m, nil
}
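For pooling type 1 the hidden states are averaged across the sequence before the two dense projections. A hedged scalar sketch of that mean-pooling step, independent of the tensor API; it roughly corresponds to what the Permute/Mean/Permute sequence achieves:

package main

import "fmt"

// meanPool averages per-token embeddings (seqLen x dim) into one dim-sized vector.
func meanPool(hidden [][]float32) []float32 {
	if len(hidden) == 0 {
		return nil
	}
	out := make([]float32, len(hidden[0]))
	for _, tok := range hidden {
		for i, v := range tok {
			out[i] += v
		}
	}
	for i := range out {
		out[i] /= float32(len(hidden))
	}
	return out
}

func main() {
	hidden := [][]float32{{1, 2}, {3, 4}, {5, 6}}
	fmt.Println(meanPool(hidden)) // [3 4]
}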

View File

@ -18,7 +18,7 @@ type Model struct {
model.Base model.Base
model.SentencePieceModel model.SentencePieceModel
*VisionModel `gguf:"v,vision"` *VisionModel `gguf:"v"`
*TextModel *TextModel
*MultiModalProjector `gguf:"mm"` *MultiModalProjector `gguf:"mm"`
@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return []input.Multimodal{{Tensor: visionOutputs}}, nil return []input.Multimodal{{Tensor: visionOutputs}}, nil
} }
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []input.Input var result []*input.Input
for _, inp := range inputs { for _, inp := range inputs {
if len(inp.Multimodal) == 0 { if len(inp.Multimodal) == 0 {
@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
inputMultimodal := inp.Multimodal[0].Tensor inputMultimodal := inp.Multimodal[0].Tensor
result = append(result, result = append(result,
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" &input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>"" &input.Input{Token: 255999}, // "<start_of_image>""
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder &input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
) )
// add image token placeholders // add image token placeholders
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result, result = append(result,
input.Input{Token: 256000}, // <end_of_image> &input.Input{Token: 256000}, // <end_of_image>
input.Input{Token: 108}, // "\n\n" &input.Input{Token: 108}, // "\n\n"
) )
} }
} }
@ -141,12 +141,11 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.Output.Forward(ctx, hiddenStates), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
} }
func init() { func init() {
model.Register("gemma3", New) model.Register("gemma3", New)
model.Register("gemma3_embed", newEmbedModel)
} }

View File

@ -159,8 +159,11 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings // set image embeddings
@ -198,5 +201,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
} }
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState) return hiddenState
} }

View File

@ -18,7 +18,7 @@ type Model struct {
model.BytePairEncoding model.BytePairEncoding
ImageProcessor ImageProcessor
*VisionModel `gguf:"v,vision"` *VisionModel `gguf:"v"`
*Projector `gguf:"mm"` *Projector `gguf:"mm"`
*TextModel *TextModel
} }
@ -134,16 +134,16 @@ type separator struct {
y bool y bool
} }
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []input.Input var result []*input.Input
for _, inp := range inputs { for _, inp := range inputs {
if len(inp.Multimodal) == 0 { if len(inp.Multimodal) == 0 {
result = append(result, inp) result = append(result, inp)
continue continue
} }
var imageInputs []input.Input var imageInputs []*input.Input
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|> imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>
for i, mm := range inp.Multimodal { for i, mm := range inp.Multimodal {
patchesPerChunk := mm.Tensor.Dim(1) patchesPerChunk := mm.Tensor.Dim(1)
@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
if i < len(inp.Multimodal)-1 { if i < len(inp.Multimodal)-1 {
separator := mm.Data.(*separator) separator := mm.Data.(*separator)
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if separator.x { if separator.x {
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|> imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
} }
if separator.y { if separator.y {
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|> imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
} }
} else { } else {
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|> imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|> imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
} }
} }

View File

@ -18,7 +18,7 @@ type Model struct {
model.BytePairEncoding model.BytePairEncoding
*TextModel *TextModel
*VisionModel `gguf:"v,vision"` *VisionModel `gguf:"v"`
*MultiModalProjector `gguf:"mm"` *MultiModalProjector `gguf:"mm"`
ImageProcessor ImageProcessor
@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END] // [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings // Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
// that can be processed together. // that can be processed together.
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []input.Input var result []*input.Input
for _, inp := range inputs { for _, inp := range inputs {
if len(inp.Multimodal) == 0 { if len(inp.Multimodal) == 0 {
result = append(result, inp) result = append(result, inp)
} else { } else {
for i, row := range inp.Multimodal { for i, row := range inp.Multimodal {
// [IMG] // [IMG]
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)}) result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...) result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
if i == len(inp.Multimodal)-1 { if i == len(inp.Multimodal)-1 {
// [IMG_END] // [IMG_END]
result = append(result, input.Input{Token: 13}) result = append(result, &input.Input{Token: 13})
} else { } else {
// [IMG_BREAK] // [IMG_BREAK]
result = append(result, input.Input{Token: 12}) result = append(result, &input.Input{Token: 12})
} }
} }
} }
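As a concrete illustration of the layout described in the comment above: an image split into two rows of three patches each tokenizes to 10 10 10 12 10 10 10 13, that is [IMG] [IMG] [IMG] [IMG_BREAK] [IMG] [IMG] [IMG] [IMG_END], with the first [IMG] of each row carrying that row's vision embeddings. A hedged sketch that builds just the placeholder IDs:

package main

import "fmt"

const (
	imgToken = 10 // [IMG]
	imgBreak = 12 // [IMG_BREAK]
	imgEnd   = 13 // [IMG_END]
)

// layoutTokens builds the placeholder stream for rows of patches; the first
// [IMG] of each row is where that row's embeddings are attached.
func layoutTokens(patchesPerRow []int) []int {
	var out []int
	for i, n := range patchesPerRow {
		for j := 0; j < n; j++ {
			out = append(out, imgToken)
		}
		if i == len(patchesPerRow)-1 {
			out = append(out, imgEnd)
		} else {
			out = append(out, imgBreak)
		}
	}
	return out
}

func main() {
	fmt.Println(layoutTokens([]int{3, 3})) // [10 10 10 12 10 10 10 13]
}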

View File

@ -17,7 +17,7 @@ type Model struct {
model.Base model.Base
model.BytePairEncoding model.BytePairEncoding
*VisionModel `gguf:"v,vision"` *VisionModel `gguf:"v"`
*TextModel *TextModel
Projector *nn.Linear `gguf:"mm.0"` Projector *nn.Linear `gguf:"mm.0"`
@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return []input.Multimodal{{Tensor: projectedOutputs}}, nil return []input.Multimodal{{Tensor: projectedOutputs}}, nil
} }
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
for i := range inputs { for i := range inputs {
if inputs[i].Multimodal != nil { if inputs[i].Multimodal != nil {
inputs[i].Token = 128256 // <|image|> inputs[i].Token = 128256 // <|image|>

View File

@ -18,7 +18,7 @@ type Model struct {
model.BytePairEncoding model.BytePairEncoding
*TextModel *TextModel
*VisionModel `gguf:"v,vision"` *VisionModel `gguf:"v"`
ImageProcessor ImageProcessor
} }
@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
} }
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []input.Input var result []*input.Input
var ( var (
imageToken int32 = 151655 imageToken int32 = 151655
@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return nil, fmt.Errorf("failed to encode image prompt: %w", err) return nil, fmt.Errorf("failed to encode image prompt: %w", err)
} }
for i := range pre { for i := range pre {
result = append(result, input.Input{Token: pre[i]}) result = append(result, &input.Input{Token: pre[i]})
} }
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1) patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
// First add the vision start token // First add the vision start token
result = append(result, input.Input{Token: visionStartToken}) result = append(result, &input.Input{Token: visionStartToken})
// Add the image token with the multimodal tensor data at the first position // Add the image token with the multimodal tensor data at the first position
result = append(result, input.Input{ result = append(result, &input.Input{
Token: imageToken, Token: imageToken,
Multimodal: inp.Multimodal, Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash, MultimodalHash: inp.MultimodalHash,
@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
}) })
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1) // Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...) result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
result = append(result, input.Input{Token: visionEndToken}) result = append(result, &input.Input{Token: visionEndToken})
} }
} }
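As a reading aid, here is a minimal sketch (not part of this change) of the per-image token layout the Qwen-2.5-VL PostTokenize above produces: pre-tokenized text, the vision start token, one image token carrying the multimodal tensor, placeholder image tokens for the remaining patches, then the vision end token. The imageToken value is taken from the hunk above; the vision start/end IDs and the pre text are placeholders, not real values.

package main

import "fmt"

// layoutForImage mirrors the ordering in PostTokenize for a single image whose
// multimodal tensor has patchesPerChunk columns. Only token IDs are modeled here.
func layoutForImage(pre []int32, patchesPerChunk int) []int32 {
	const (
		imageToken       int32 = 151655 // from the hunk above
		visionStartToken int32 = -1     // placeholder: real ID not shown in this hunk
		visionEndToken   int32 = -2     // placeholder: real ID not shown in this hunk
	)
	out := append([]int32{}, pre...)
	out = append(out, visionStartToken)
	// the first image token position also carries the multimodal tensor and hash
	out = append(out, imageToken)
	for i := 1; i < patchesPerChunk; i++ {
		out = append(out, imageToken) // remaining placeholder positions
	}
	return append(out, visionEndToken)
}

func main() {
	// e.g. two text tokens and four patches per chunk
	fmt.Println(layoutForImage([]int32{1, 2}, 4))
}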

View File

@ -2,7 +2,6 @@ package model
import ( import (
"container/heap" "container/heap"
"context"
"fmt" "fmt"
"log/slog" "log/slog"
"strconv" "strconv"
@ -25,7 +24,7 @@ func (spm SentencePieceModel) Vocabulary() *Vocabulary {
} }
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
slog.Log(context.TODO(), logutil.LevelTrace, "Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{} counter := map[int]int{}
var maxTokenLen int var maxTokenLen int
@ -39,7 +38,7 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
} }
} }
slog.Log(context.TODO(), logutil.LevelTrace, "Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL], logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen) "max token len", maxTokenLen)
@ -182,12 +181,11 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
} }
} }
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
if addSpecial && len(ids) > 0 { if addSpecial && len(ids) > 0 {
ids = spm.vocab.addSpecials(ids) ids = spm.vocab.addSpecials(ids)
} }
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil return ids, nil
} }
@ -246,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
} }
} }
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String()) logutil.Trace("decoded", "ids", ids, "string", sb.String())
return sb.String(), nil return sb.String(), nil
} }

View File

@ -49,7 +49,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS) slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
} }
slog.Debug("adding bos token to prompt", "id", v.BOS) slog.Debug("adding bos token to prompt", "id", v.BOS[0])
ids = append([]int32{v.BOS[0]}, ids...) ids = append([]int32{v.BOS[0]}, ids...)
} }
@ -58,7 +58,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS) slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
} }
slog.Debug("adding eos token to prompt", "id", v.EOS) slog.Debug("adding eos token to prompt", "id", v.EOS[0])
ids = append(ids, v.EOS[0]) ids = append(ids, v.EOS[0])
} }

View File

@ -557,12 +557,10 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
var think *api.ThinkValue var think *api.ThinkValue
if r.Reasoning != nil { if r.Reasoning != nil {
options["reasoning"] = *r.Reasoning.Effort
think = &api.ThinkValue{ think = &api.ThinkValue{
Value: *r.Reasoning.Effort, Value: *r.Reasoning.Effort,
} }
} else if r.ReasoningEffort != nil { } else if r.ReasoningEffort != nil {
options["reasoning"] = *r.ReasoningEffort
think = &api.ThinkValue{ think = &api.ThinkValue{
Value: *r.ReasoningEffort, Value: *r.ReasoningEffort,
} }

View File

@ -246,7 +246,7 @@ func filesForModel(path string) ([]string, error) {
for _, match := range matches { for _, match := range matches {
if ct, err := detectContentType(match); err != nil { if ct, err := detectContentType(match); err != nil {
return nil, err return nil, err
} else if ct != contentType { } else if len(contentType) > 0 && ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match) return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
} }
} }
@ -255,7 +255,8 @@ func filesForModel(path string) ([]string, error) {
} }
var files []string var files []string
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 { // some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are // safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...) files = append(files, st...)

View File

@ -46,7 +46,7 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
} }
// Locking: Operations on InputCacheSlot (including finding one // Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be be held that serializes // through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and llama.Decode // these operations with each other and llama.Decode
type InputCacheSlot struct { type InputCacheSlot struct {

View File

@ -78,7 +78,7 @@ func (c *InputCache) Close() {
} }
// Locking: Operations on InputCacheSlot (including finding one // Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be be held that serializes // through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and processBatch // these operations with each other and processBatch
type InputCacheSlot struct { type InputCacheSlot struct {
@ -86,7 +86,7 @@ type InputCacheSlot struct {
Id int Id int
// Inputs that are stored in the KV cache // Inputs that are stored in the KV cache
Inputs []input.Input Inputs []*input.Input
// is this cache actively being processed as part of a sequence? // is this cache actively being processed as part of a sequence?
InUse bool InUse bool
@ -95,7 +95,7 @@ type InputCacheSlot struct {
lastUsed time.Time lastUsed time.Time
} }
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) { func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) {
var slot *InputCacheSlot var slot *InputCacheSlot
var numPast int32 var numPast int32
var err error var err error
@ -113,6 +113,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
return nil, nil, err return nil, nil, err
} }
if !cachePrompt {
numPast = 0
}
slot.InUse = true slot.InUse = true
slot.lastUsed = time.Now() slot.lastUsed = time.Now()
@ -146,7 +150,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
return slot, prompt, nil return slot, prompt, nil
} }
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1) longest := int32(-1)
var longestSlot *InputCacheSlot var longestSlot *InputCacheSlot
@ -169,7 +173,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
return longestSlot, longest, nil return longestSlot, longest, nil
} }
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now() oldest := time.Now()
var oldestSlot *InputCacheSlot var oldestSlot *InputCacheSlot
@ -205,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
if longest > 0 && longestSlot != oldestSlot { if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs)) len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input.Input, longest) oldestSlot.Inputs = make([]*input.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil { if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@ -215,7 +219,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
return oldestSlot, longest, nil return oldestSlot, longest, nil
} }
func countCommonPrefix(a []input.Input, b []input.Input) int32 { func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
var count int32 var count int32
for i := range a { for i := range a {
@ -250,7 +254,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
} }
type ErrReprocessInputs struct { type ErrReprocessInputs struct {
Inputs []input.Input Inputs []*input.Input
} }
func (e *ErrReprocessInputs) Error() string { func (e *ErrReprocessInputs) Error() string {
@ -283,13 +287,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
"id", slot.Id, "error", err) "id", slot.Id, "error", err)
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard) // Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard)) newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep]) copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Reset the cache // Reset the cache
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32) _ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
slot.Inputs = []input.Input{} slot.Inputs = []*input.Input{}
// Return error with inputs that need to be reprocessed // Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs} return &ErrReprocessInputs{Inputs: newInputs}

View File

@ -13,50 +13,50 @@ import (
func TestCountCommon(t *testing.T) { func TestCountCommon(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
t1 []input.Input t1 []*input.Input
t2 []input.Input t2 []*input.Input
expected int32 expected int32
}{ }{
{ {
name: "Equal", name: "Equal",
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3, expected: 3,
}, },
{ {
name: "Prefix", name: "Prefix",
t1: []input.Input{{Token: 1}}, t1: []*input.Input{{Token: 1}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1, expected: 1,
}, },
{ {
name: "Image Prefix", name: "Image Prefix",
t1: []input.Input{{MultimodalHash: 1}}, t1: []*input.Input{{MultimodalHash: 1}},
t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}}, t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
expected: 1, expected: 1,
}, },
{ {
name: "Mixed", name: "Mixed",
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}}, t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}}, t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
expected: 2, expected: 2,
}, },
{ {
name: "Mixed, Same Length", name: "Mixed, Same Length",
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}}, t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 2}}, t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}},
expected: 1, expected: 1,
}, },
{ {
name: "Empty", name: "Empty",
t1: []input.Input{}, t1: []*input.Input{},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0, expected: 0,
}, },
{ {
name: "Both Empty", name: "Both Empty",
t1: []input.Input{}, t1: []*input.Input{},
t2: []input.Input{}, t2: []*input.Input{},
expected: 0, expected: 0,
}, },
} }
@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cache InputCache cache InputCache
prompt []input.Input prompt []*input.Input
longest expected longest expected
best expected best expected
}{ }{
@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}}, }},
prompt: []input.Input{{Token: 1}}, prompt: []*input.Input{{Token: 1}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0}, best: expected{result: 0, len: 0},
}, },
@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}}, Inputs: []*input.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input.Input{{Token: 1}, {Token: 2}}, prompt: []*input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2}, longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2}, best: expected{result: 1, len: 2},
}, },
@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}}, }},
prompt: []input.Input{{Token: 2}}, prompt: []*input.Input{{Token: 2}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0}, best: expected{result: 1, len: 0},
}, },
@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}}, prompt: []*input.Input{{Token: 1}},
longest: expected{result: 0, len: 1}, longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1}, best: expected{result: 1, len: 1},
}, },
@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}}, Inputs: []*input.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input.Input{{Token: 2}, {Token: 3}}, prompt: []*input.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0}, best: expected{result: 1, len: 0},
}, },
@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: true, InUse: true,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{{Token: 1}}, Inputs: []*input.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input.Input{{Token: 1}, {Token: 2}}, prompt: []*input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1}, longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2}, best: expected{result: 1, len: 2},
}, },
@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cache InputCache cache InputCache
prompt []input.Input prompt []*input.Input
wantErr bool wantErr bool
expectedSlotId int expectedSlotId int
expectedPrompt int // expected length of remaining prompt expectedPrompt int // expected length of remaining prompt
@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false, wantErr: false,
expectedSlotId: 0, expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains expectedPrompt: 1, // Only token 3 remains
@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false, wantErr: false,
expectedSlotId: 0, expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains expectedPrompt: 1, // Only token 3 remains
@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}, {Token: 2}}, prompt: []*input.Input{{Token: 1}, {Token: 2}},
wantErr: false, wantErr: false,
expectedSlotId: 0, expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling expectedPrompt: 1, // Should leave 1 token for sampling
@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: true, InUse: true,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true, wantErr: true,
expectedSlotId: -1, expectedSlotId: -1,
expectedPrompt: -1, expectedPrompt: -1,
@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt) slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true)
// Check error state // Check error state
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
numCtx int32 numCtx int32
inputs []input.Input inputs []*input.Input
numKeep int32 numKeep int32
cacheErr bool cacheErr bool
wantErr any wantErr any
@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
{ {
name: "Normal shift", name: "Normal shift",
numCtx: 10, numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2, numKeep: 2,
cacheErr: false, // No error cacheErr: false, // No error
wantErr: nil, wantErr: nil,
@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
{ {
name: "Cache removal fails", name: "Cache removal fails",
numCtx: 10, numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2, numKeep: 2,
cacheErr: true, cacheErr: true,
wantErr: &ErrReprocessInputs{}, wantErr: &ErrReprocessInputs{},
@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
} }
slot := &InputCacheSlot{ slot := &InputCacheSlot{
Id: 123, Id: 123,
Inputs: make([]input.Input, len(tt.inputs)), Inputs: make([]*input.Input, len(tt.inputs)),
} }
copy(slot.Inputs, tt.inputs) copy(slot.Inputs, tt.inputs)

View File

@ -11,12 +11,14 @@ import (
"image" "image"
"log" "log"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"os" "os"
"reflect" "reflect"
"regexp" "regexp"
"runtime" "runtime"
"runtime/debug"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -51,10 +53,10 @@ type Sequence struct {
iBatch int iBatch int
// prompt inputs left to evaluate // prompt inputs left to evaluate
inputs []input.Input inputs []*input.Input
// inputs that have been added to a batch but not yet submitted to Forward // inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []input.Input pendingInputs []*input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []string
@ -182,8 +184,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs // inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and // by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images // decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) { func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
var inputs []input.Input var inputs []*input.Input
var ctxs []ml.Context var ctxs []ml.Context
var mmStore multimodalStore var mmStore multimodalStore
@ -210,7 +212,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
} }
for _, t := range tokens { for _, t := range tokens {
inputs = append(inputs, input.Input{Token: t}) inputs = append(inputs, &input.Input{Token: t})
} }
// image - decode and store // image - decode and store
@ -243,7 +245,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
mmStore.addMultimodal(imageEmbeddings) mmStore.addMultimodal(imageEmbeddings)
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true postTokenize = true
} }
} }
@ -259,6 +261,37 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
return inputs, ctxs, mmStore, nil return inputs, ctxs, mmStore, nil
} }
type batchState struct {
// id provides a counter for trace logging batches
id int
// ctx holds the backend context used for this batch
ctx ml.Context
// modelOutput holds the outputs from this batch
modelOutput ml.Tensor
// batchInputs holds the input token pointers, which may start as
// placeholders that are filled in before calling ctx.Compute
batchInputs []*input.Input
// batch contains the inputs for a model forward pass
batch input.Batch
// full set of seqs at the time this batch was initiated
seqs []*Sequence
// Signaled when this batch's inputs are ready and compute can proceed
inputsReadyCh chan struct{}
// Signaled when Compute is about to begin on this batch, and
// seqs have been updated to prepare for the next batch
computeStartedCh chan struct{}
// Signaled when this batch's outputs are complete and the next batch can proceed
outputsReadyCh chan struct{}
}
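The three channels above drive the async pipelining between forwardBatch and computeBatch: forwardBatch waits on the pending batch's computeStartedCh before assembling the next batch, and the pending batch's outputsReadyCh is chained in as the next batch's inputsReadyCh. Below is a stripped-down, illustrative sketch of that handoff, assuming buffered channels of size 1 as in the hunk; sequences, tensors and sampling are omitted, so this is not the real implementation.

package main

import "fmt"

type miniBatch struct {
	id               int
	inputsReadyCh    chan struct{}
	computeStartedCh chan struct{}
	outputsReadyCh   chan struct{}
}

// forward mimics forwardBatch: it blocks until the previous batch has started
// computing, then chains the previous outputs channel as this batch's inputs channel.
func forward(prev *miniBatch, id int) *miniBatch {
	b := &miniBatch{
		id:               id,
		computeStartedCh: make(chan struct{}, 1),
		outputsReadyCh:   make(chan struct{}, 1),
	}
	if prev != nil {
		<-prev.computeStartedCh               // safe to touch the seqs/inputs again
		b.inputsReadyCh = prev.outputsReadyCh // our inputs are ready once prev's outputs are
	} else {
		b.inputsReadyCh = make(chan struct{}, 1)
		b.inputsReadyCh <- struct{}{} // no pending batch: inputs are ready immediately
	}
	return b
}

// compute mimics computeBatch running asynchronously for one batch.
func compute(b *miniBatch, done chan<- int) {
	<-b.inputsReadyCh                // placeholder tokens have been filled in
	b.computeStartedCh <- struct{}{} // unblock forward() for the next batch
	// ... ctx.Compute, Floats and sampling would happen here ...
	b.outputsReadyCh <- struct{}{} // unblock whichever batch chained onto us
	done <- b.id
}

func main() {
	done := make(chan int, 3)
	var prev *miniBatch
	for i := 0; i < 3; i++ {
		b := forward(prev, i)
		go compute(b, done)
		prev = b
	}
	for i := 0; i < 3; i++ {
		fmt.Println("batch finished:", <-done)
	}
}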
type Server struct { type Server struct {
// modelPath is the location of the model to be loaded // modelPath is the location of the model to be loaded
modelPath string modelPath string
@ -290,6 +323,12 @@ type Server struct {
// TODO (jmorganca): make this n_batch // TODO (jmorganca): make this n_batch
batchSize int batchSize int
// Used to signal a hard failure during async processing which will panic the runner
hardErrCh chan error
// Simple counter used only for trace logging batches
batchID int
// protects access to everything below this line // protects access to everything below this line
// this is context state needed for decoding // this is context state needed for decoding
mu sync.Mutex mu sync.Mutex
@ -362,33 +401,73 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
s.seqsSem.Release(1) s.seqsSem.Release(1)
} }
// track batch state between forwardBatch, computeBatch and predictForwardBatch
func (s *Server) run(ctx context.Context) { func (s *Server) run(ctx context.Context) {
s.ready.Wait() s.ready.Wait()
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
var activeBatch batchState
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case err := <-s.hardErrCh:
panic(err)
default: default:
err := s.processBatch() var err error
activeBatch, err = s.forwardBatch(activeBatch)
if err != nil { if err != nil {
panic(err) panic(err)
} }
if supportsAsync {
go s.computeBatch(activeBatch)
} else {
s.computeBatch(activeBatch)
}
} }
} }
} }
func (s *Server) processBatch() error { // forwardBatch will calculate a batch.
func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
// If we have a pending batch still processing, wait until Compute has started
// before setting up the next batch so the seqs' inputs are ready to receive their
// token values and we get the correct input pointers for the batchInputs
if pendingBatch.ctx != nil {
logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
<-pendingBatch.computeStartedCh
logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the outputs from the pending batch to the next batch's inputs
} else {
logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID)
// No pendingBatch, so the inputs will be ready in the seqs immediately
nextBatch.inputsReadyCh = make(chan struct{}, 1)
nextBatch.inputsReadyCh <- struct{}{}
}
s.mu.Lock() s.mu.Lock()
for s.allNil() { for s.allNil() {
s.cond.Wait() // Wait until an item is added s.cond.Wait() // Wait until an item is added
} }
defer s.mu.Unlock() defer s.mu.Unlock()
ctx := s.model.Backend().NewContext() nextBatch.ctx = s.model.Backend().NewContext()
defer ctx.Close() defer func() {
if err != nil {
nextBatch.ctx.Close()
nextBatch.ctx = nil
}
}()
nextBatch.id = s.batchID
nextBatch.seqs = append([]*Sequence{}, s.seqs...)
nextBatch.computeStartedCh = make(chan struct{}, 1)
nextBatch.outputsReadyCh = make(chan struct{}, 1)
var batchInputs []int32 // Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input
var batch input.Batch var batch input.Batch
resumeSeq := -1 resumeSeq := -1
@ -396,7 +475,6 @@ func (s *Server) processBatch() error {
for range s.seqs { for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs) seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx] seq := s.seqs[seqIdx]
if seq == nil { if seq == nil {
continue continue
} }
@ -404,12 +482,13 @@ func (s *Server) processBatch() error {
// if past the num predict limit // if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength) s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
continue continue
} }
if !s.cache.enabled { if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...) seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []input.Input{} seq.cache.Inputs = []*input.Input{}
} }
batchSize := s.batchSize batchSize := s.batchSize
@ -442,25 +521,28 @@ func (s *Server) processBatch() error {
break break
} }
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil { if err != nil {
var reprocess *ErrReprocessInputs var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) { if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing // Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...) seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest // Skip this sequence but continue processing the rest
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
err = nil
continue continue
} else { } else {
return err return
} }
} }
} }
batchInputs = append(batchInputs, inp.Token) batchInputs = append(batchInputs, seq.inputs[i])
if inp.Multimodal != nil { if inp.Multimodal != nil {
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false) var mm []input.Multimodal
mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
if err != nil { if err != nil {
return err return
} }
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm}) batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
} }
@ -472,6 +554,7 @@ func (s *Server) processBatch() error {
if i+1 == len(seq.inputs) { if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
} }
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp) seq.pendingInputs = append(seq.pendingInputs, inp)
} }
@ -485,73 +568,168 @@ func (s *Server) processBatch() error {
} }
if len(batchInputs) == 0 { if len(batchInputs) == 0 {
return nil logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID)
nextBatch.ctx.Close()
nextBatch.ctx = nil
return
} }
s.batchID++
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch) // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
if err != nil { if err != nil {
return fmt.Errorf("failed to decode batch: %w", err) err = fmt.Errorf("failed to build graph: %w", err)
return
}
nextBatch.batchInputs = batchInputs
nextBatch.batch = batch
return
}
// Async processing of the next batch
func (s *Server) computeBatch(activeBatch batchState) {
if activeBatch.ctx == nil {
// Nothing to compute
return
}
defer activeBatch.ctx.Close()
// Wait until inputs are ready
logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
<-activeBatch.inputsReadyCh
logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id)
// Once we complete, signal the next batch of inputs are ready
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
defer func() {
logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id)
activeBatch.outputsReadyCh <- struct{}{}
}()
s.mu.Lock()
// Gather the actual input token values now that they're ready
batchInputs := make([]int32, len(activeBatch.batchInputs))
for i := range batchInputs {
batchInputs[i] = activeBatch.batchInputs[i].Token
} }
logits := modelOutput.Floats() // Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
// decoded tokens.
nextBatchTokens := make([]*input.Input, len(s.seqs))
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
for i, seq := range s.seqs { for i, seq := range s.seqs {
iBatches[i] = -1
if seq == nil { if seq == nil {
continue continue
} }
// Skip over any newly added or skipped sequences
if activeBatch.seqs[i] == nil {
continue
}
// After calling Forward, pending inputs are now in the cache // Detect if the sequence we're processing has already been completed and replaced
// with a new sequence
if seq != activeBatch.seqs[i] {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
// Pending inputs will actually be in the cache after we call Compute.
// However, we have already resolved any placeholder tokens.
//
// It's possible for incoming sequences to look at the values that we've
// added to the cache here and start relying on them before we've done
// the computation. This is OK as long as we ensure that this batch's
// computation happens before any future batch's and we never fail
// (unless we take down the whole runner).
if len(seq.pendingInputs) > 0 { if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input.Input{} seq.pendingInputs = []*input.Input{}
} }
// don't sample prompt processing // don't sample prompt processing
if len(seq.inputs) != 0 { if len(seq.inputs) != 0 {
if !s.cache.enabled { if !s.cache.enabled {
return errors.New("caching disabled but unable to fit entire input in a batch") s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
s.mu.Unlock()
return
} }
continue continue
} }
seq.numPredicted++ seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
iBatches[i] = seq.iBatch
}
// At this point the seqs are ready for forwardBatch to move forward, so unblock it
s.mu.Unlock()
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
activeBatch.ctx.ComputeWithNotify(
func() {
logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
activeBatch.computeStartedCh <- struct{}{}
},
activeBatch.modelOutput)
outputs := activeBatch.modelOutput.Floats()
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
s.mu.Lock()
defer s.mu.Unlock()
logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id)
for i, seq := range s.seqs {
if seq == nil || nextBatchTokens[i] == nil {
continue
}
if seq.numPredicted == 1 { if seq.numPredicted == 1 {
seq.startGenerationTime = time.Now() seq.startGenerationTime = time.Now()
} }
// if done processing the prompt, generate an embedding and return // if done processing the prompt, generate an embedding and return
if seq.embeddingOnly { if seq.embeddingOnly {
// TODO(jessegross): Embedding support seq.embedding <- outputs
slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, llm.DoneReasonStop) s.removeSequence(i, llm.DoneReasonStop)
continue continue
} }
// sample a token // sample a token
vocabSize := len(logits) / len(batch.Outputs) vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]) token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil { if err != nil {
return fmt.Errorf("failed to sample token: %w", err) s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return
} }
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break // if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) { if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back // TODO (jmorganca): we should send this back
// as it's important for the /api/generate context // as it's important for the /api/generate context
// seq.responses <- piece // seq.responses <- piece
logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
s.removeSequence(i, llm.DoneReasonStop) s.removeSequence(i, llm.DoneReasonStop)
continue continue
} }
piece, err := s.model.(model.TextProcessor).Decode([]int32{token}) piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
if err != nil { if err != nil {
return err s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
return
} }
seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece) seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "") sequence := strings.Join(seq.pendingResponses, "")
@ -575,6 +753,7 @@ func (s *Server) processBatch() error {
if tokenTruncated || origLen == newLen { if tokenTruncated || origLen == newLen {
tokenLen-- tokenLen--
} }
seq.cache.Inputs = seq.cache.Inputs[:tokenLen] seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
s.removeSequence(i, llm.DoneReasonStop) s.removeSequence(i, llm.DoneReasonStop)
@ -593,8 +772,6 @@ func (s *Server) processBatch() error {
s.removeSequence(i, llm.DoneReasonConnectionClosed) s.removeSequence(i, llm.DoneReasonConnectionClosed)
} }
} }
return nil
} }
func (s *Server) completion(w http.ResponseWriter, r *http.Request) { func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@ -665,7 +842,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false found := false
for i, sq := range s.seqs { for i, sq := range s.seqs {
if sq == nil { if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs) seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
if err != nil { if err != nil {
s.mu.Unlock() s.mu.Unlock()
s.seqsSem.Release(1) s.seqsSem.Release(1)
@ -721,6 +898,67 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} }
} }
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
return
}
var req llm.EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
if err != nil {
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embedding request due to client closing the connection")
} else {
http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
found = true
break
}
}
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: <-seq.embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) { func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
@ -736,7 +974,10 @@ func (s *Server) reserveWorstCaseGraph() error {
defer ctx.Close() defer ctx.Close()
var err error var err error
inputs := make([]input.Input, s.batchSize) inputs := make([]*input.Input, s.batchSize)
for i := range inputs {
inputs[i] = &input.Input{}
}
mmStore := newMultimodalStore() mmStore := newMultimodalStore()
// Multimodal strategy: // Multimodal strategy:
@ -778,8 +1019,11 @@ func (s *Server) reserveWorstCaseGraph() error {
} }
if len(inputs) < s.batchSize { if len(inputs) < s.batchSize {
newInputs := make([]input.Input, s.batchSize) newInputs := make([]*input.Input, s.batchSize)
copy(newInputs, inputs) copy(newInputs, inputs)
for i := len(inputs); i < s.batchSize; i++ {
newInputs[i] = &input.Input{}
}
inputs = newInputs inputs = newInputs
} }
} }
@ -842,6 +1086,7 @@ func (s *Server) allocModel(
// Convert memory allocation panics to errors // Convert memory allocation panics to errors
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
debug.PrintStack()
if err, ok := r.(error); ok { if err, ok := r.(error); ok {
panicErr = err panicErr = err
} else { } else {
@ -1011,6 +1256,7 @@ func Execute(args []string) error {
server := &Server{ server := &Server{
modelPath: *mpath, modelPath: *mpath,
status: llm.ServerStatusLaunched, status: llm.ServerStatusLaunched,
hardErrCh: make(chan error, 1),
} }
server.cond = sync.NewCond(&server.mu) server.cond = sync.NewCond(&server.mu)
@ -1029,10 +1275,7 @@ func Execute(args []string) error {
mux := http.NewServeMux() mux := http.NewServeMux()
// TODO: support embeddings // TODO: support embeddings
mux.HandleFunc("POST /load", server.load) mux.HandleFunc("POST /load", server.load)
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /embedding", server.embeddings)
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
})
mux.HandleFunc("POST /completion", server.completion) mux.HandleFunc("POST /completion", server.completion)
mux.HandleFunc("GET /health", server.health) mux.HandleFunc("GET /health", server.health)

View File

@ -32,6 +32,7 @@ import (
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
@ -45,6 +46,18 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
return true
}
}
return false
}
func experimentEnabled(name string) bool { func experimentEnabled(name string) bool {
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
} }
@ -176,7 +189,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
// expire the runner // expire the runner
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m) s.sched.expireRunner(m)
c.JSON(http.StatusOK, api.GenerateResponse{ c.JSON(http.StatusOK, api.GenerateResponse{
@ -194,12 +207,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
useHarmony := shouldUseHarmony(*m) && !req.Raw useHarmony := shouldUseHarmony(m) && !req.Raw
var harmonyMessageHandler *HarmonyMessageHandler var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *HarmonyToolCallAccumulator var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if useHarmony { if useHarmony {
harmonyMessageHandler = NewHarmonyMessageHandler() harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.harmonyParser.AddImplicitStart() harmonyMessageHandler.HarmonyParser.AddImplicitStart()
harmonyToolParser = harmonyMessageHandler.CreateToolParser() harmonyToolParser = harmonyMessageHandler.CreateToolParser()
} }
@ -1531,7 +1544,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
// expire the runner // expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
model, err := GetModel(req.Model) model, err := GetModel(req.Model)
if err != nil { if err != nil {
switch { switch {
@ -1603,19 +1616,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
msgs = filterThinkTags(msgs, m) msgs = filterThinkTags(msgs, m)
var harmonyMessageHandler *HarmonyMessageHandler var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *HarmonyToolCallAccumulator var harmonyToolParser *harmony.HarmonyToolCallAccumulator
useHarmony := shouldUseHarmony(*m) useHarmony := shouldUseHarmony(m)
processedTools := req.Tools processedTools := req.Tools
if useHarmony { if useHarmony {
harmonyMessageHandler = NewHarmonyMessageHandler() harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
var lastMessage *api.Message var lastMessage *api.Message
if len(msgs) > 0 { if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1] lastMessage = &msgs[len(msgs)-1]
} }
harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage) harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
harmonyToolParser = harmonyMessageHandler.CreateToolParser() harmonyToolParser = harmonyMessageHandler.CreateToolParser()
// make a copy of tools to pass to the chat prompt. Function names may be // make a copy of tools to pass to the chat prompt. Function names may be
@ -1623,7 +1636,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
processedTools = make([]api.Tool, len(req.Tools)) processedTools = make([]api.Tool, len(req.Tools))
copy(processedTools, req.Tools) copy(processedTools, req.Tools)
for i, tool := range processedTools { for i, tool := range processedTools {
processedTools[i].Function.Name = harmonyMessageHandler.functionNameMap.ConvertAndAdd(tool.Function.Name) processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
} }
} }
@ -1660,6 +1673,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
OpeningTag: openingTag, OpeningTag: openingTag,
ClosingTag: closingTag, ClosingTag: closingTag,
} }
if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
thinkingState.AddContent(openingTag)
}
} }
var toolParser *tools.Parser var toolParser *tools.Parser
@ -1705,7 +1722,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
toolName, toolContent := harmonyToolParser.Drain() toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil { if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.") *toolName = strings.TrimPrefix(*toolName, "functions.")
*toolName = harmonyMessageHandler.functionNameMap.OriginalFromConverted(*toolName) *toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
var args api.ToolCallFunctionArguments var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil { if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())

View File

@ -969,3 +969,233 @@ func TestGenerate(t *testing.T) {
} }
}) })
} }
func TestChatWithPromptEndingInThinkTag(t *testing.T) {
gin.SetMode(gin.TestMode)
// Helper to create a standard thinking test setup
setupThinkingTest := func(t *testing.T) (*mockRunner, *Server) {
mock := &mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := &Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{llama: mock}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a model with thinking support
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
// Create model with thinking template that adds <think> at the end
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-thinking",
Files: map[string]string{"file.gguf": digest},
Template: `{{- range .Messages }}
{{- if eq .Role "user" }}user: {{ .Content }}
{{ else if eq .Role "assistant" }}assistant: {{ if .Thinking }}<think>{{ .Thinking }}</think>{{ end }}{{ .Content }}
{{ end }}{{ end }}<think>`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
return mock, s
}
mock, s := setupThinkingTest(t)
// Helper to test chat responses
testChatRequest := func(t *testing.T, name string, userContent string, modelResponse string, expectedThinking string, expectedContent string, think bool) {
t.Run(name, func(t *testing.T) {
mock.CompletionResponse = llm.CompletionResponse{
Content: modelResponse,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
}
mock.CompletionFn = nil
streamRequest := false
req := api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{
{Role: "user", Content: userContent},
},
Stream: &streamRequest,
}
if think {
req.Think = &api.ThinkValue{Value: think}
}
w := createRequest(t, s.ChatHandler, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.Thinking != expectedThinking {
t.Errorf("expected thinking %q, got %q", expectedThinking, resp.Message.Thinking)
}
if resp.Message.Content != expectedContent {
t.Errorf("expected content %q, got %q", expectedContent, resp.Message.Content)
}
})
}
// Test cases - Note: Template adds <think> at the end, and leading whitespace after <think> is eaten by the parser
testChatRequest(t, "basic thinking response",
"Help me solve this problem",
" Let me think about this step by step... </think> The answer is 42.",
"Let me think about this step by step... ",
"The answer is 42.",
true)
testChatRequest(t, "thinking with multiple sentences",
"Explain quantum computing",
" First, I need to understand the basics. Quantum bits can be in superposition. </think> Quantum computing uses quantum mechanics principles.",
"First, I need to understand the basics. Quantum bits can be in superposition. ",
"Quantum computing uses quantum mechanics principles.",
true)
testChatRequest(t, "no thinking content",
"What is 2+2?",
"</think> The answer is 4.",
"",
"The answer is 4.",
true)
testChatRequest(t, "thinking disabled but template still adds think tag",
"Simple question",
" My thoughts </think> The answer.",
"",
" My thoughts </think> The answer.",
false)
// Test streaming response with template-added <think>
t.Run("streaming with thinking", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
// Verify the prompt ends with <think> due to template
if !strings.HasSuffix(r.Prompt, "<think>") {
t.Errorf("expected prompt to end with <think>, got: %q", r.Prompt)
}
// Simulate streaming chunks
responses := []llm.CompletionResponse{
{Content: " I need to consider", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
{Content: " multiple factors here...", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
{Content: " </think> Based on my analysis,", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
{Content: " the solution is straightforward.", Done: true, DoneReason: llm.DoneReasonStop, PromptEvalCount: 1, PromptEvalDuration: 1, EvalCount: 1, EvalDuration: 1},
}
for _, resp := range responses {
select {
case <-ctx.Done():
return ctx.Err()
default:
fn(resp)
time.Sleep(10 * time.Millisecond)
}
}
return nil
}
think := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Analyze this complex problem"}},
Think: &api.ThinkValue{Value: think},
Stream: &stream,
})
wg.Wait()
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
// Parse streaming responses
decoder := json.NewDecoder(w.Body)
var allThinking, allContent strings.Builder
for {
var resp api.ChatResponse
if err := decoder.Decode(&resp); err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
allThinking.WriteString(resp.Message.Thinking)
allContent.WriteString(resp.Message.Content)
}
// Note: Leading whitespace after <think> is eaten by the parser
if got := allThinking.String(); got != "I need to consider multiple factors here... " {
t.Errorf("expected thinking %q, got %q", "I need to consider multiple factors here... ", got)
}
if got := allContent.String(); got != "Based on my analysis, the solution is straightforward." {
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
}
})
}
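
The cases above all rely on the same observable behaviour: the template has already emitted the opening <think> tag, so everything the model produces before </think> is treated as thinking (with leading whitespace eaten) and everything after it as content. A minimal, self-contained sketch of that split, illustrative only and not the server's actual parser:

// Illustrative sketch of the split behaviour exercised by the tests above.
// The template has already emitted <think>, so the model output begins
// inside the thinking span.
package main

import (
	"fmt"
	"strings"
)

// splitThinking returns (thinking, content) for output produced after an
// opening <think> tag. Leading whitespace on either side of the closing tag
// is eaten, matching the behaviour noted in the test comments.
func splitThinking(output string) (string, string) {
	thinking, content, found := strings.Cut(output, "</think>")
	if !found {
		// no closing tag yet: everything so far counts as thinking
		return strings.TrimLeft(thinking, " \t\n"), ""
	}
	return strings.TrimLeft(thinking, " \t\n"), strings.TrimLeft(content, " \t\n")
}

func main() {
	th, c := splitThinking(" Let me think about this step by step... </think> The answer is 42.")
	fmt.Printf("thinking=%q content=%q\n", th, c)
	// thinking="Let me think about this step by step... " content="The answer is 42."
}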

View File

@@ -103,7 +103,9 @@ func eat(s *Parser) (string, string, bool) {
 			// note that we use the original content, not the trimmed one because we
 			// don't want to eat any whitespace in the real content if there were no
 			// thinking tags
-			return "", s.acc.String(), false
+			untrimmed := s.acc.String()
+			s.acc.Reset()
+			return "", untrimmed, false
 		}
 	case thinkingState_ThinkingStartedEatingWhitespace:
 		trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
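
The change above hands back the buffered content and clears the accumulator in the same step. A toy sketch (not the real parser) of why the Reset matters: a buffer that is returned without being cleared re-emits its old contents on the next flush.

package main

import (
	"bytes"
	"fmt"
)

type accum struct{ buf bytes.Buffer }

func (a *accum) add(s string) { a.buf.WriteString(s) }

// flush returns everything buffered so far and clears the buffer,
// mirroring the untrimmed/Reset pattern in the fix above.
func (a *accum) flush() string {
	out := a.buf.String()
	a.buf.Reset() // without this, "abc" would be returned again below
	return out
}

func main() {
	var a accum
	a.add("abc")
	fmt.Println(a.flush()) // "abc"
	a.add("def")
	fmt.Println(a.flush()) // "def" (would be "abcdef" without the Reset)
}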

View File

@@ -58,6 +58,15 @@ func TestThinkingStreaming(t *testing.T) {
 				wantContent:    " abc",
 				wantStateAfter: thinkingState_ThinkingDone,
 			},
+			// regression test for a bug where we were transitioning directly to
+			// ThinkingDone without clearing the buffer. This would cause the first
+			// step to be output twice
+			{
+				input:          "def",
+				wantThinking:   "",
+				wantContent:    "def",
+				wantStateAfter: thinkingState_ThinkingDone,
+			},
 		},
 	},
 	{

View File

@@ -224,22 +224,45 @@ func findArguments(buffer []byte) (map[string]any, int) {
 		return nil, 0
 	}
 
+	start := -1
 	var braces int
-	var start int = -1
+	var inString, escaped bool
 
-	for i, c := range buffer {
+	for i := range buffer {
+		c := buffer[i]
+
+		if escaped {
+			escaped = false
+			continue
+		}
+
+		if c == '\\' {
+			escaped = true
+			continue
+		}
+
+		if c == '"' {
+			inString = !inString
+			continue
+		}
+
+		if inString {
+			continue
+		}
+
 		if c == '{' {
 			if braces == 0 {
 				start = i
 			}
 			braces++
-		} else if c == '}' && braces > 0 {
+		} else if c == '}' {
 			braces--
 			if braces == 0 && start != -1 {
 				object := buffer[start : i+1]
 
 				var data map[string]any
 				if err := json.Unmarshal(object, &data); err != nil {
+					// not a valid object, keep looking
 					start = -1
 					continue
 				}

@@ -282,6 +305,10 @@ func findArguments(buffer []byte) (map[string]any, int) {
 				return data, i
 			}
 
+			if braces < 0 {
+				braces = 0
+			}
 		}
 	}
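
The rewritten scan only counts braces that occur outside JSON string literals, tracks escapes so an escaped quote does not end a string early, and clamps the counter so a stray closing brace cannot leave it negative. A self-contained sketch of the same technique, using a hypothetical firstObject helper rather than the real findArguments:

package main

import (
	"encoding/json"
	"fmt"
)

// firstObject returns the first complete top-level JSON object in buffer,
// ignoring braces that appear inside string literals or behind escapes.
func firstObject(buffer []byte) (map[string]any, bool) {
	start := -1
	var braces int
	var inString, escaped bool

	for i := range buffer {
		c := buffer[i]

		if escaped {
			escaped = false
			continue
		}
		if c == '\\' {
			escaped = true
			continue
		}
		if c == '"' {
			inString = !inString
			continue
		}
		if inString {
			continue
		}

		switch c {
		case '{':
			if braces == 0 {
				start = i
			}
			braces++
		case '}':
			braces--
			if braces == 0 && start != -1 {
				var data map[string]any
				if err := json.Unmarshal(buffer[start:i+1], &data); err != nil {
					// not a valid object, keep looking
					start = -1
					continue
				}
				return data, true
			}
			if braces < 0 {
				braces = 0
			}
		}
	}
	return nil, false
}

func main() {
	data, _ := firstObject([]byte(`{"arguments": {"code": "if (x > 0) { return true; }"}}`))
	fmt.Println(data["arguments"]) // braces inside the string do not break the scan
}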

View File

@@ -1,6 +1,7 @@
 package tools
 
 import (
+	"strings"
 	"testing"
 	"text/template"

@@ -40,13 +41,7 @@ func TestParser(t *testing.T) {
 			Function: api.ToolFunction{
 				Name:        "get_temperature",
 				Description: "Retrieve the temperature for a given location",
-				Parameters: struct {
-					Type       string                      `json:"type"`
-					Defs       any                         `json:"$defs,omitempty"`
-					Items      any                         `json:"items,omitempty"`
-					Required   []string                    `json:"required"`
-					Properties map[string]api.ToolProperty `json:"properties"`
-				}{
+				Parameters: api.ToolFunctionParameters{
 					Type:     "object",
 					Required: []string{"city"},
 					Properties: map[string]api.ToolProperty{

@@ -68,13 +63,7 @@ func TestParser(t *testing.T) {
 			Function: api.ToolFunction{
 				Name:        "get_conditions",
 				Description: "Retrieve the current weather conditions for a given location",
-				Parameters: struct {
-					Type       string                      `json:"type"`
-					Defs       any                         `json:"$defs,omitempty"`
-					Items      any                         `json:"items,omitempty"`
-					Required   []string                    `json:"required"`
-					Properties map[string]api.ToolProperty `json:"properties"`
-				}{
+				Parameters: api.ToolFunctionParameters{
 					Type: "object",
 					Properties: map[string]api.ToolProperty{
 						"location": {

@@ -104,13 +93,7 @@ func TestParser(t *testing.T) {
 			Function: api.ToolFunction{
 				Name:        "get_address",
 				Description: "Get the address of a given location",
-				Parameters: struct {
-					Type       string                      `json:"type"`
-					Defs       any                         `json:"$defs,omitempty"`
-					Items      any                         `json:"items,omitempty"`
-					Required   []string                    `json:"required"`
-					Properties map[string]api.ToolProperty `json:"properties"`
-				}{
+				Parameters: api.ToolFunctionParameters{
 					Type: "object",
 					Properties: map[string]api.ToolProperty{
 						"location": {

@@ -126,13 +109,7 @@ func TestParser(t *testing.T) {
 			Function: api.ToolFunction{
 				Name:        "add",
 				Description: "Add two numbers",
-				Parameters: struct {
-					Type       string                      `json:"type"`
-					Defs       any                         `json:"$defs,omitempty"`
-					Items      any                         `json:"items,omitempty"`
-					Required   []string                    `json:"required"`
-					Properties map[string]api.ToolProperty `json:"properties"`
-				}{
+				Parameters: api.ToolFunctionParameters{
 					Type: "object",
 					Properties: map[string]api.ToolProperty{
 						"a": {
@@ -1140,11 +1117,163 @@ func TestFindArguments(t *testing.T) {
 		},
 		{
 			name:   "deepseek",
-			buffer: []byte(`", "arguments": {"location": "Tokyo"}}</tool_call>`),
+			buffer: []byte(`"arguments": {"location": "Tokyo"}}</tool_call>`),
 			want: map[string]any{
 				"location": "Tokyo",
 			},
 		},
{
name: "string with braces",
buffer: []byte(`{"name": "process_code", "arguments": {"code": "if (x > 0) { return true; }"}}`),
want: map[string]any{
"code": "if (x > 0) { return true; }",
},
},
{
name: "string with nested json",
buffer: []byte(`{"name": "send_data", "arguments": {"payload": "{\"nested\": {\"key\": \"value\"}}"}}`),
want: map[string]any{
"payload": `{"nested": {"key": "value"}}`,
},
},
{
name: "string with escaped quotes and braces",
buffer: []byte(`{"name": "analyze", "arguments": {"text": "The JSON is: {\"key\": \"val{ue}\"}"}}`),
want: map[string]any{
"text": `The JSON is: {"key": "val{ue}"}`,
},
},
{
name: "multiple objects with string containing braces",
buffer: []byte(`{"name": "test", "arguments": {"query": "find } in text"}} {"name": "other"}`),
want: map[string]any{
"query": "find } in text",
},
},
{
name: "unmatched closing brace in string",
buffer: []byte(`{"name": "search", "arguments": {"pattern": "regex: }"}}`),
want: map[string]any{
"pattern": "regex: }",
},
},
{
name: "complex nested with mixed braces",
buffer: []byte(`{"name": "analyze", "arguments": {"data": "{\"items\": [{\"value\": \"}\"}, {\"code\": \"if (x) { return y; }\"}]}"}}`),
want: map[string]any{
"data": `{"items": [{"value": "}"}, {"code": "if (x) { return y; }"}]}`,
},
},
{
name: "string with newline and braces",
buffer: []byte(`{"name": "format", "arguments": {"template": "{\n \"key\": \"value\"\n}"}}`),
want: map[string]any{
"template": "{\n \"key\": \"value\"\n}",
},
},
{
name: "string with unicode escape",
buffer: []byte(`{"name": "test", "arguments": {"text": "Unicode: \u007B and \u007D"}}`),
want: map[string]any{
"text": "Unicode: { and }",
},
},
{
name: "array arguments",
buffer: []byte(`{"name": "batch", "arguments": ["item1", "item2", "{\"nested\": true}"]}`),
want: nil, // This should return nil because arguments is not a map
},
{
name: "escaped backslash before quote",
buffer: []byte(`{"name": "path", "arguments": {"dir": "C:\\Program Files\\{App}\\"}}`),
want: map[string]any{
"dir": `C:\Program Files\{App}\`,
},
},
{
name: "single quotes not treated as string delimiters",
buffer: []byte(`{"name": "query", "arguments": {"sql": "SELECT * FROM users WHERE name = '{admin}'"}}`),
want: map[string]any{
"sql": "SELECT * FROM users WHERE name = '{admin}'",
},
},
{
name: "incomplete json at buffer end",
buffer: []byte(`{"name": "test", "arguments": {"data": "some {"`),
want: nil,
},
{
name: "multiple escaped quotes",
buffer: []byte(`{"name": "echo", "arguments": {"msg": "He said \"Hello {World}\" loudly"}}`),
want: map[string]any{
"msg": `He said "Hello {World}" loudly`,
},
},
{
name: "json with comments style string",
buffer: []byte(`{"name": "code", "arguments": {"snippet": "// This is a comment with { and }"}}`),
want: map[string]any{
"snippet": "// This is a comment with { and }",
},
},
{
name: "consecutive escaped backslashes",
buffer: []byte(`{"name": "test", "arguments": {"path": "C:\\\\{folder}\\\\"}}`),
want: map[string]any{
"path": `C:\\{folder}\\`,
},
},
{
name: "empty string with braces after",
buffer: []byte(`{"name": "test", "arguments": {"a": "", "b": "{value}"}}`),
want: map[string]any{
"a": "",
"b": "{value}",
},
},
{
name: "unicode in key names",
buffer: []byte(`{"name": "test", "arguments": {"key{": "value", "key}": "value2"}}`),
want: map[string]any{
"key{": "value",
"key}": "value2",
},
},
{
name: "very long string with braces",
buffer: []byte(`{"name": "test", "arguments": {"data": "` + strings.Repeat("a{b}c", 100) + `"}}`),
want: map[string]any{
"data": strings.Repeat("a{b}c", 100),
},
},
{
name: "tab characters and braces",
buffer: []byte(`{"name": "test", "arguments": {"code": "\tif (true) {\n\t\treturn;\n\t}"}}`),
want: map[string]any{
"code": "\tif (true) {\n\t\treturn;\n\t}",
},
},
{
name: "null byte in string",
buffer: []byte(`{"name": "test", "arguments": {"data": "before\u0000{after}"}}`),
want: map[string]any{
"data": "before\x00{after}",
},
},
{
name: "escaped quote at end of string",
buffer: []byte(`{"name": "test", "arguments": {"data": "text with quote at end\\\""}}`),
want: map[string]any{
"data": `text with quote at end\"`,
},
},
{
name: "mixed array and object in arguments",
buffer: []byte(`{"name": "test", "arguments": {"items": ["{", "}", {"key": "value"}]}}`),
want: map[string]any{
"items": []any{"{", "}", map[string]any{"key": "value"}},
},
},
 	}
 
 	for _, tt := range tests {