diff --git a/api/types.go b/api/types.go index 699dba428..e2c63b622 100644 --- a/api/types.go +++ b/api/types.go @@ -85,10 +85,11 @@ type GenerateRequest struct { Options map[string]any `json:"options"` // Think controls whether thinking/reasoning models will think before - // responding. Needs to be a pointer so we can distinguish between false + // responding. Can be a boolean (true/false) or a string ("high", "medium", "low") + // for supported models. Needs to be a pointer so we can distinguish between false // (request that thinking _not_ be used) and unset (use the old behavior // before this option was introduced) - Think *bool `json:"think,omitempty"` + Think *ThinkValue `json:"think,omitempty"` } // ChatRequest describes a request sent by [Client.Chat]. @@ -116,8 +117,9 @@ type ChatRequest struct { Options map[string]any `json:"options"` // Think controls whether thinking/reasoning models will think before - // responding - Think *bool `json:"think,omitempty"` + // responding. Can be a boolean (true/false) or a string ("high", "medium", "low") + // for supported models. + Think *ThinkValue `json:"think,omitempty"` } type Tools []Tool @@ -508,6 +510,8 @@ type GenerateResponse struct { Context []int `json:"context,omitempty"` Metrics + + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } // ModelDetails provides details about a model. @@ -677,6 +681,113 @@ func DefaultOptions() Options { } } +// ThinkValue represents a value that can be a boolean or a string ("high", "medium", "low") +type ThinkValue struct { + // Value can be a bool or string + Value interface{} +} + +// IsValid checks if the ThinkValue is valid +func (t *ThinkValue) IsValid() bool { + if t == nil || t.Value == nil { + return true // nil is valid (means not set) + } + + switch v := t.Value.(type) { + case bool: + return true + case string: + return v == "high" || v == "medium" || v == "low" + default: + return false + } +} + +// IsBool returns true if the value is a boolean +func (t *ThinkValue) IsBool() bool { + if t == nil || t.Value == nil { + return false + } + _, ok := t.Value.(bool) + return ok +} + +// IsString returns true if the value is a string +func (t *ThinkValue) IsString() bool { + if t == nil || t.Value == nil { + return false + } + _, ok := t.Value.(string) + return ok +} + +// AsBool returns the value as a bool (true if enabled in any way) +func (t *ThinkValue) AsBool() bool { + if t == nil || t.Value == nil { + return false + } + + switch v := t.Value.(type) { + case bool: + return v + case string: + // Any string value ("high", "medium", "low") means thinking is enabled + return v == "high" || v == "medium" || v == "low" + default: + return false + } +} + +// AsString returns the value as a string +func (t *ThinkValue) AsString() string { + if t == nil || t.Value == nil { + return "" + } + + switch v := t.Value.(type) { + case string: + return v + case bool: + if v { + return "medium" // Default level when just true + } + return "" + default: + return "" + } +} + +// UnmarshalJSON implements json.Unmarshaler +func (t *ThinkValue) UnmarshalJSON(data []byte) error { + // Try to unmarshal as bool first + var b bool + if err := json.Unmarshal(data, &b); err == nil { + t.Value = b + return nil + } + + // Try to unmarshal as string + var s string + if err := json.Unmarshal(data, &s); err == nil { + // Validate string values + if s != "high" && s != "medium" && s != "low" { + return fmt.Errorf("invalid think value: %q (must be \"high\", \"medium\", \"low\", true, or false)", s) + } + t.Value = s + 
return nil + } + + return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\")") +} + +// MarshalJSON implements json.Marshaler +func (t *ThinkValue) MarshalJSON() ([]byte, error) { + if t == nil || t.Value == nil { + return []byte("null"), nil + } + return json.Marshal(t.Value) +} + type Duration struct { time.Duration } diff --git a/api/types_test.go b/api/types_test.go index 9c2fb1f11..841853808 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -374,24 +374,21 @@ func TestPropertyType_MarshalJSON(t *testing.T) { } func TestThinking_UnmarshalJSON(t *testing.T) { - trueVal := true - falseVal := false - tests := []struct { name string input string - expectedThinking *bool + expectedThinking *ThinkValue expectedError bool }{ { name: "true", input: `{ "think": true }`, - expectedThinking: &trueVal, + expectedThinking: &ThinkValue{Value: true}, }, { name: "false", input: `{ "think": false }`, - expectedThinking: &falseVal, + expectedThinking: &ThinkValue{Value: false}, }, { name: "unset", @@ -399,8 +396,23 @@ func TestThinking_UnmarshalJSON(t *testing.T) { expectedThinking: nil, }, { - name: "invalid", - input: `{ "think": "true" }`, + name: "string_high", + input: `{ "think": "high" }`, + expectedThinking: &ThinkValue{Value: "high"}, + }, + { + name: "string_medium", + input: `{ "think": "medium" }`, + expectedThinking: &ThinkValue{Value: "medium"}, + }, + { + name: "string_low", + input: `{ "think": "low" }`, + expectedThinking: &ThinkValue{Value: "low"}, + }, + { + name: "invalid_string", + input: `{ "think": "invalid" }`, expectedThinking: nil, expectedError: true, }, @@ -414,7 +426,12 @@ func TestThinking_UnmarshalJSON(t *testing.T) { require.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expectedThinking, req.Think) + if test.expectedThinking == nil { + assert.Nil(t, req.Think) + } else { + require.NotNil(t, req.Think) + assert.Equal(t, test.expectedThinking.Value, req.Think.Value) + } } }) } diff --git a/cmd/cmd.go b/cmd/cmd.go index 1d1d116ba..de3fc86a7 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -322,11 +322,23 @@ func RunHandler(cmd *cobra.Command, args []string) error { thinkFlag := cmd.Flags().Lookup("think") if thinkFlag.Changed { - think, err := cmd.Flags().GetBool("think") + thinkStr, err := cmd.Flags().GetString("think") if err != nil { return err } - opts.Think = &think + + // Handle different values for --think + switch thinkStr { + case "", "true": + // --think or --think=true + opts.Think = &api.ThinkValue{Value: true} + case "false": + opts.Think = &api.ThinkValue{Value: false} + case "high", "medium", "low": + opts.Think = &api.ThinkValue{Value: thinkStr} + default: + return fmt.Errorf("invalid value for --think: %q (must be true, false, high, medium, or low)", thinkStr) + } } else { opts.Think = nil } @@ -977,7 +989,7 @@ type runOptions struct { Options map[string]any MultiModal bool KeepAlive *api.Duration - Think *bool + Think *api.ThinkValue HideThinking bool } @@ -1017,10 +1029,11 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState) } switch ch { - case ' ': + case ' ', '\t': state.wordBuffer = "" - case '\n': + case '\n', '\r': state.lineLength = 0 + state.wordBuffer = "" default: state.wordBuffer += string(ch) } @@ -1078,6 +1091,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { }() var state *displayResponseState = &displayResponseState{} + var thinkingContent strings.Builder var latest api.ChatResponse var fullResponse strings.Builder var 
thinkTagOpened bool = false @@ -1097,14 +1111,21 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { if !thinkTagOpened { fmt.Print(thinkingOutputOpeningText(false)) thinkTagOpened = true + thinkTagClosed = false } + thinkingContent.WriteString(response.Message.Thinking) displayResponse(response.Message.Thinking, opts.WordWrap, state) } content := response.Message.Content - if thinkTagOpened && !thinkTagClosed && content != "" { + if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.Message.ToolCalls) > 0) { + if !strings.HasSuffix(thinkingContent.String(), "\n") { + fmt.Println() + } fmt.Print(thinkingOutputClosingText(false)) + thinkTagOpened = false thinkTagClosed = true + state = &displayResponseState{} } // purposefully not putting thinking blocks in the response, which would // only be needed if we later added tool calling to the cli (they get @@ -1112,6 +1133,13 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { // about to finish some tool calls) fullResponse.WriteString(content) + if response.Message.ToolCalls != nil { + toolCalls := response.Message.ToolCalls + if len(toolCalls) > 0 { + fmt.Print(renderToolCalls(toolCalls, false)) + } + } + displayResponse(content, opts.WordWrap, state) return nil @@ -1196,6 +1224,7 @@ func generate(cmd *cobra.Command, opts runOptions) error { }() var state *displayResponseState = &displayResponseState{} + var thinkingContent strings.Builder var thinkTagOpened bool = false var thinkTagClosed bool = false @@ -1213,17 +1242,31 @@ func generate(cmd *cobra.Command, opts runOptions) error { if !thinkTagOpened { fmt.Print(thinkingOutputOpeningText(plainText)) thinkTagOpened = true + thinkTagClosed = false } + thinkingContent.WriteString(response.Thinking) displayResponse(response.Thinking, opts.WordWrap, state) } - if thinkTagOpened && !thinkTagClosed && content != "" { + if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.ToolCalls) > 0) { + if !strings.HasSuffix(thinkingContent.String(), "\n") { + fmt.Println() + } fmt.Print(thinkingOutputClosingText(plainText)) + thinkTagOpened = false thinkTagClosed = true + state = &displayResponseState{} } displayResponse(content, opts.WordWrap, state) + if response.ToolCalls != nil { + toolCalls := response.ToolCalls + if len(toolCalls) > 0 { + fmt.Print(renderToolCalls(toolCalls, plainText)) + } + } + return nil } @@ -1463,7 +1506,8 @@ func NewCLI() *cobra.Command { runCmd.Flags().Bool("insecure", false, "Use an insecure registry") runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().String("format", "", "Response format (e.g. json)") - runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models") + runCmd.Flags().String("think", "", "Enable thinking mode: true/false or high/medium/low for supported models") + runCmd.Flags().Lookup("think").NoOptDefVal = "true" runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)") stopCmd := &cobra.Command{ @@ -1613,7 +1657,7 @@ func NewCLI() *cobra.Command { // to false). // // If capabilities are not provided, we fetch them from the server. 
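
For illustration (not part of this patch): a client populates the new Think field with either a bool or one of the reasoning levels, and these are the same values the reworked --think flag feeds through (a bare --think maps to true via NoOptDefVal). A minimal sketch, assuming only what this patch adds to api/types.go; the model name is a placeholder:

package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	req := &api.ChatRequest{
		Model:    "some-thinking-model", // placeholder model name
		Messages: []api.Message{{Role: "user", Content: "Why is the sky blue?"}},
		// Think accepts a bool or one of "high", "medium", "low".
		Think: &api.ThinkValue{Value: "high"},
	}
	fmt.Println(req.Think.AsString(), req.Think.AsBool()) // prints: high true
}
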
-func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) { +func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*api.ThinkValue, error) { if explicitlySetByUser { return runOpts.Think, nil } @@ -1640,9 +1684,34 @@ func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicit } if thinkingSupported { - thinking := true - return &thinking, nil + return &api.ThinkValue{Value: true}, nil } return nil, nil } + +func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string { + out := "" + formatExplanation := "" + formatValues := "" + if !plainText { + formatExplanation = readline.ColorGrey + readline.ColorBold + formatValues = readline.ColorDefault + out += formatExplanation + } + for i, toolCall := range toolCalls { + argsAsJSON, err := json.Marshal(toolCall.Function.Arguments) + if err != nil { + return "" + } + if i > 0 { + out += "\n" + } + // all tool calls are unexpected since we don't currently support registering any in the CLI + out += fmt.Sprintf(" Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation) + } + if !plainText { + out += readline.ColorDefault + } + return out +} diff --git a/cmd/interactive.go b/cmd/interactive.go index 08ab4947b..e290d84ce 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -272,16 +272,29 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } fmt.Println("Set 'quiet' mode.") case "think": - think := true - opts.Think = &think + thinkValue := api.ThinkValue{Value: true} + var maybeLevel string + if len(args) > 2 { + maybeLevel = args[2] + } + if maybeLevel != "" { + // TODO(drifkin): validate the level, could be model dependent + // though... It will also be validated on the server once a call is + // made. 
+ thinkValue.Value = maybeLevel + } + opts.Think = &thinkValue thinkExplicitlySet = true if client, err := api.ClientFromEnvironment(); err == nil { ensureThinkingSupport(cmd.Context(), client, opts.Model) } - fmt.Println("Set 'think' mode.") + if maybeLevel != "" { + fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel) + } else { + fmt.Println("Set 'think' mode.") + } case "nothink": - think := false - opts.Think = &think + opts.Think = &api.ThinkValue{Value: false} thinkExplicitlySet = true if client, err := api.ClientFromEnvironment(); err == nil { ensureThinkingSupport(cmd.Context(), client, opts.Model) @@ -478,7 +491,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { assistant, err := chat(cmd, opts) if err != nil { - if strings.Contains(err.Error(), "does not support thinking") { + if strings.Contains(err.Error(), "does not support thinking") || + strings.Contains(err.Error(), "invalid think value") { fmt.Printf("error: %v\n", err) sb.Reset() continue diff --git a/convert/convert.go b/convert/convert.go index 63b3bf661..bed59a575 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -202,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error { conv = &bertModel{} case "CohereForCausalLM": conv = &commandrModel{} + case "GptOssForCausalLM": + conv = &gptossModel{} default: return fmt.Errorf("unsupported architecture %q", p.Architectures[0]) } diff --git a/convert/convert_gptoss.go b/convert/convert_gptoss.go new file mode 100644 index 000000000..bd362169b --- /dev/null +++ b/convert/convert_gptoss.go @@ -0,0 +1,178 @@ +package convert + +import ( + "bytes" + "cmp" + "encoding/binary" + "io" + "slices" + "strings" + + "github.com/ollama/ollama/fs/ggml" + "github.com/pdevine/tensor" + "github.com/pdevine/tensor/native" +) + +type gptossModel struct { + ModelParameters + HiddenLayers uint32 `json:"num_hidden_layers"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + AttentionHeads uint32 `json:"num_attention_heads"` + KeyValueHeads uint32 `json:"num_key_value_heads"` + HeadDim uint32 `json:"head_dim"` + Experts uint32 `json:"num_experts"` + ExpertsPerToken uint32 `json:"experts_per_token"` + RMSNormEpsilon float32 `json:"rms_norm_eps"` + InitialContextLength uint32 `json:"initial_context_length"` + RopeTheta float32 `json:"rope_theta"` + RopeScalingFactor float32 `json:"rope_scaling_factor"` + SlidingWindow uint32 `json:"sliding_window"` +} + +var _ ModelConverter = (*gptossModel)(nil) + +func (m *gptossModel) KV(t *Tokenizer) ggml.KV { + kv := m.ModelParameters.KV(t) + kv["general.architecture"] = "gptoss" + kv["general.file_type"] = uint32(4) + kv["gptoss.context_length"] = uint32(m.RopeScalingFactor * float32(m.InitialContextLength)) + kv["gptoss.block_count"] = m.HiddenLayers + kv["gptoss.embedding_length"] = m.HiddenSize + kv["gptoss.feed_forward_length"] = m.IntermediateSize + kv["gptoss.expert_count"] = m.Experts + kv["gptoss.expert_used_count"] = m.ExpertsPerToken + kv["gptoss.attention.head_count"] = m.AttentionHeads + kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads + kv["gptoss.attention.key_length"] = m.HeadDim + kv["gptoss.attention.value_length"] = m.HeadDim + kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5) + kv["gptoss.attention.sliding_window"] = m.SlidingWindow + kv["gptoss.rope.freq_base"] = m.RopeTheta + kv["gptoss.rope.scaling.factor"] = m.RopeScalingFactor + kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength + 
kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|> + kv["tokenizer.ggml.add_bos_token"] = false + kv["tokenizer.ggml.eos_token_id"] = uint32(199999) // <|endoftext|> + kv["tokenizer.ggml.eos_token_ids"] = []int32{ + 199999, /* <|endoftext|> */ + 200002, /* <|return|> */ + 200012, /* <|call|> */ + } + kv["tokenizer.ggml.add_eos_token"] = false + return kv +} + +func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor { + var out []*ggml.Tensor + mxfp4s := make(map[string]*mxfp4) + for _, t := range ts { + if strings.HasSuffix(t.Name(), ".blocks") || strings.HasSuffix(t.Name(), ".scales") { + dot := strings.LastIndex(t.Name(), ".") + name, suffix := t.Name()[:dot], t.Name()[dot+1:] + if _, ok := mxfp4s[name]; !ok { + mxfp4s[name] = &mxfp4{} + } + + switch suffix { + case "blocks": + mxfp4s[name].blocks = t + case "scales": + mxfp4s[name].scales = t + } + } else { + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + } + + for name, mxfp4 := range mxfp4s { + dims := mxfp4.blocks.Shape() + out = append(out, &ggml.Tensor{ + Name: name, + Kind: uint32(ggml.TensorTypeMXFP4), + Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2}, + WriterTo: mxfp4, + }) + } + + return out +} + +func (m *gptossModel) Replacements() []string { + return []string{ + // noop replacements so other replacements will not be applied + ".blocks", ".blocks", + ".scales", ".scales", + // real replacements + "block", "blk", + "attn.norm", "attn_norm", + "attn.qkv", "attn_qkv", + "attn.sinks", "attn_sinks", + "attn.out", "attn_out", + "mlp.norm", "ffn_norm", + "mlp.gate", "ffn_gate_inp", + "mlp.mlp1_", "ffn_gate_up_exps.", + "mlp.mlp2_", "ffn_down_exps.", + "embedding", "token_embd", + "norm", "output_norm", + "unembedding", "output", + "scale", "weight", + } +} + +type mxfp4 struct { + blocks, scales Tensor +} + +func (m *mxfp4) WriteTo(w io.Writer) (int64, error) { + var b bytes.Buffer + if _, err := m.blocks.WriteTo(&b); err != nil { + return 0, err + } + + blocksDims := make([]int, len(m.blocks.Shape())) + for i, d := range m.blocks.Shape() { + blocksDims[i] = int(d) + } + + var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes())) + + var s bytes.Buffer + if _, err := m.scales.WriteTo(&s); err != nil { + return 0, err + } + + scalesDims := slices.Repeat([]int{1}, len(m.blocks.Shape())) + for i, d := range m.scales.Shape() { + scalesDims[i] = int(d) + } + + var scales tensor.Tensor = tensor.New(tensor.WithShape(scalesDims...), tensor.WithBacking(s.Bytes())) + + out, err := tensor.Concat(3, scales, blocks) + if err != nil { + return 0, err + } + + out = tensor.Materialize(out) + + if err := out.Reshape(out.Shape().TotalSize()); err != nil { + return 0, err + } + + u8s, err := native.VectorU8(out.(*tensor.Dense)) + if err != nil { + return 0, err + } + + if err := binary.Write(w, binary.LittleEndian, u8s); err != nil { + return 0, err + } + + return 0, nil +} diff --git a/convert/reader.go b/convert/reader.go index 07d12f0dd..367e91a29 100644 --- a/convert/reader.go +++ b/convert/reader.go @@ -31,8 +31,10 @@ func (t tensorBase) Shape() []uint64 { } const ( - tensorKindF32 uint32 = iota - tensorKindF16 + tensorKindFP32 uint32 = iota + tensorKindFP16 + tensorKindMXFP4 = 4 + tensorKindBF16 = 30 ) func (t tensorBase) Kind() uint32 { @@ -43,16 +45,16 @@ func (t tensorBase) Kind() uint32 { t.name == "v.pre_tile_position_embd.weight" || t.name == "v.post_tile_position_embd.weight" { // these tensors are 
always F32 - return 0 + return tensorKindFP32 } switch len(t.shape) { case 0: panic("invalid tensor shape") case 1: - return tensorKindF32 + return tensorKindFP32 default: - return tensorKindF16 + return tensorKindFP16 } } diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index f182a656c..63f31631d 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -93,6 +93,15 @@ type safetensor struct { *tensorBase } +func (st safetensor) Kind() uint32 { + kind := st.tensorBase.Kind() + if st.dtype == "BF16" && kind != tensorKindFP32 { + kind = tensorKindBF16 + } + + return kind +} + func (st safetensor) Clone() Tensor { return &safetensor{ fs: st.fs, @@ -150,6 +159,9 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { } f32s = bfloat16.DecodeFloat32(u8s) + case "U8": + // U8 tensors do not support repacking or type conversion. + return io.CopyN(w, f, st.size) default: return 0, fmt.Errorf("unknown data type: %s", st.dtype) } @@ -162,15 +174,18 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { } switch st.Kind() { - case tensorKindF32: + case tensorKindFP32: return 0, binary.Write(w, binary.LittleEndian, f32s) - case tensorKindF16: + case tensorKindFP16: f16s := make([]uint16, len(f32s)) for i := range f32s { f16s[i] = float16.Fromfloat32(f32s[i]).Bits() } return 0, binary.Write(w, binary.LittleEndian, f16s) + case tensorKindBF16: + u8s := bfloat16.EncodeFloat32(f32s) + return 0, binary.Write(w, binary.LittleEndian, u8s) default: return 0, fmt.Errorf("unknown storage type: %d", st.Kind()) } diff --git a/convert/tensor_test.go b/convert/tensor_test.go index 0b2db5baa..3a34bbff6 100644 --- a/convert/tensor_test.go +++ b/convert/tensor_test.go @@ -72,236 +72,787 @@ func mul(shape []uint64) int { } func TestSplitDim(t *testing.T) { - r := fakeTensor{ - name: "a.b", - shape: []uint64{3, 4}, - data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - } - - t.Run("no split", func(t *testing.T) { - for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) { - if tt.Name != "x.b" { - t.Fatalf("expected name 'x', got '%s'", tt.Name) - } - - if !slices.Equal(tt.Shape, []uint64{3, 4}) { - t.Fatalf("expected shape [3, 4], got %v", tt.Shape) - } - - var b bytes.Buffer - if _, err := tt.WriteTo(&b); err != nil { - t.Fatal(err) - } - - f32s := make([]float32, mul(tt.Shape)) - if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { - t.Fatal(err) - } - - if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) { - t.Fatalf("expected data [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], got %v", f32s) - } + t.Run("2d", func(t *testing.T) { + r := fakeTensor{ + name: "a.b", + shape: []uint64{3, 4}, + data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, } + + t.Run("no split", func(t *testing.T) { + for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) { + if tt.Name != "x.b" { + t.Fatalf("expected name 'x', got '%s'", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 4}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + }) + + t.Run("even split", func(t *testing.T) { + next, stop 
:= iter.Pull(splitDim(&r, 1, + split{Replacer: strings.NewReplacer("a", "x")}, + split{Replacer: strings.NewReplacer("b", "y")}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{2, 3, 6, 7, 10, 11}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + }) + + t.Run("uneven split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 0, + split{Replacer: strings.NewReplacer("a", "x"), dim: 2}, + split{Replacer: strings.NewReplacer("b", "y"), dim: 1}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{2, 4}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + }) + + t.Run("three way split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 0, + split{Replacer: strings.NewReplacer("a", "x"), dim: 1}, + split{Replacer: strings.NewReplacer("b", "y"), dim: 1}, + split{Replacer: strings.NewReplacer("b", "z"), dim: 1}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != 
nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{4, 5, 6, 7}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.z" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + }) + + t.Run("uneven three way split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 1, + split{Replacer: strings.NewReplacer("a", "x"), dim: 2}, + split{Replacer: strings.NewReplacer("b", "y"), dim: 1}, + split{Replacer: strings.NewReplacer("b", "z"), dim: 1}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 1}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{2, 6, 10}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.z" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 1}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, 
binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{3, 7, 11}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + }) + + t.Run("split with transpose", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 1, + split{Replacer: strings.NewReplacer("a", "x")}, + split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) { + return tensor.Transpose(tt, 1, 0) + }}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{2, 6, 10, 3, 7, 11}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + }) }) - - t.Run("even split", func(t *testing.T) { - next, stop := iter.Pull(splitDim(&r, 1, - split{Replacer: strings.NewReplacer("a", "x")}, - split{Replacer: strings.NewReplacer("b", "y")}, - )) - defer stop() - - { - tt, ok := next() - if !ok { - t.Fatal("expected at least one split") - } - - if tt.Name != "x.b" { - t.Fatal("expected name 'x.b', got", tt.Name) - } - - if !slices.Equal(tt.Shape, []uint64{3, 2}) { - t.Fatal("expected shape [3, 2], got", tt.Shape) - } - - var b bytes.Buffer - if _, err := tt.WriteTo(&b); err != nil { - t.Fatal(err) - } - - f32s := make([]float32, mul(tt.Shape)) - if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { - t.Fatal(err) - } - - if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) { - t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s) - } + t.Run("3d", func(t *testing.T) { + r := fakeTensor{ + name: "a.b", + shape: []uint64{3, 4, 2}, + data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, } - { - tt, ok := next() - if !ok { - t.Fatal("expected at least one split") + t.Run("no split", func(t *testing.T) { + for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) { + if tt.Name != "x.b" { + t.Fatalf("expected name 'x', got '%s'", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 4, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); diff != "" { 
+ t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + }) + + t.Run("even split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 1, + split{Replacer: strings.NewReplacer("a", "x")}, + split{Replacer: strings.NewReplacer("b", "y")}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } } - if tt.Name != "a.y" { - t.Fatal("expected name 'a.y', got", tt.Name) + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + }) + + t.Run("uneven split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 0, + split{Replacer: strings.NewReplacer("a", "x"), dim: 2}, + split{Replacer: strings.NewReplacer("b", "y"), dim: 1}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{2, 4, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } } - if !slices.Equal(tt.Shape, []uint64{3, 2}) { - t.Fatal("expected shape [3, 2], got", tt.Shape) + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'a.y', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{16, 17, 18, 19, 20, 21, 22, 23}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + }) + + t.Run("three way split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 0, + split{Replacer: strings.NewReplacer("a", "x"), dim: 1}, + split{Replacer: strings.NewReplacer("b", "y"), dim: 1}, + split{Replacer: 
strings.NewReplacer("b", "z"), dim: 1}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } } - var b bytes.Buffer - if _, err := tt.WriteTo(&b); err != nil { - t.Fatal(err) + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.y" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11, 12, 13, 14, 15}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } } - f32s := make([]float32, mul(tt.Shape)) - if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { - t.Fatal(err) + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "a.z" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{16, 17, 18, 19, 20, 21, 22, 23}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } + } + }) + + t.Run("uneven three way split", func(t *testing.T) { + next, stop := iter.Pull(splitDim(&r, 1, + split{Replacer: strings.NewReplacer("a", "x"), dim: 2}, + split{Replacer: strings.NewReplacer("b", "y"), dim: 1}, + split{Replacer: strings.NewReplacer("b", "z"), dim: 1}, + )) + defer stop() + + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } + + if tt.Name != "x.b" { + t.Fatal("expected name 'x.b', got", tt.Name) + } + + if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } } - if !slices.Equal(f32s, []float32{2, 3, 6, 7, 10, 11}) { - t.Fatal("expected data [2, 3, 6, 7, 10, 11], got", f32s) - } - } - }) + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } - t.Run("uneven split", func(t *testing.T) { - next, stop := iter.Pull(splitDim(&r, 0, - split{Replacer: strings.NewReplacer("a", "x"), dim: 2}, - split{Replacer: strings.NewReplacer("b", 
"y"), dim: 1}, - )) - defer stop() + if tt.Name != "a.y" { + t.Fatal("expected name 'x.b', got", tt.Name) + } - { - tt, ok := next() - if !ok { - t.Fatal("expected at least one split") + if diff := cmp.Diff(tt.Shape, []uint64{3, 1, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } + + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } + + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(f32s, []float32{4, 5, 12, 13, 20, 21}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } } - if tt.Name != "x.b" { - t.Fatal("expected name 'x.b', got", tt.Name) - } + { + tt, ok := next() + if !ok { + t.Fatal("expected at least one split") + } - if !slices.Equal(tt.Shape, []uint64{2, 4}) { - t.Fatal("expected shape [2, 4], got", tt.Shape) - } + if tt.Name != "a.z" { + t.Fatal("expected name 'x.b', got", tt.Name) + } - var b bytes.Buffer - if _, err := tt.WriteTo(&b); err != nil { - t.Fatal(err) - } + if diff := cmp.Diff(tt.Shape, []uint64{3, 1, 2}); diff != "" { + t.Errorf("unexpected shape (-want +got):\n%s", diff) + } - f32s := make([]float32, mul(tt.Shape)) - if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { - t.Fatal(err) - } + var b bytes.Buffer + if _, err := tt.WriteTo(&b); err != nil { + t.Fatal(err) + } - if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}) { - t.Fatal("expected data [0, 1, 2, 3, 4, 5, 6, 7], got", f32s) - } - } + f32s := make([]float32, mul(tt.Shape)) + if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { + t.Fatal(err) + } - { - tt, ok := next() - if !ok { - t.Fatal("expected at least one split") + if diff := cmp.Diff(f32s, []float32{6, 7, 14, 15, 22, 23}); diff != "" { + t.Errorf("unexpected data (-want +got):\n%s", diff) + } } - - if tt.Name != "a.y" { - t.Fatal("expected name 'a.y', got", tt.Name) - } - - if !slices.Equal(tt.Shape, []uint64{1, 4}) { - t.Fatal("expected shape [1, 4], got", tt.Shape) - } - - var b bytes.Buffer - if _, err := tt.WriteTo(&b); err != nil { - t.Fatal(err) - } - - f32s := make([]float32, mul(tt.Shape)) - if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { - t.Fatal(err) - } - - if !slices.Equal(f32s, []float32{8, 9, 10, 11}) { - t.Fatal("expected data [8, 9, 10, 11], got", f32s) - } - } - }) - - t.Run("split with transpose", func(t *testing.T) { - next, stop := iter.Pull(splitDim(&r, 1, - split{Replacer: strings.NewReplacer("a", "x")}, - split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) { - return tensor.Transpose(tt, 1, 0) - }}, - )) - defer stop() - - { - tt, ok := next() - if !ok { - t.Fatal("expected at least one split") - } - - if tt.Name != "x.b" { - t.Fatal("expected name 'x.b', got", tt.Name) - } - - if !slices.Equal(tt.Shape, []uint64{3, 2}) { - t.Fatal("expected shape [3, 2], got", tt.Shape) - } - - var b bytes.Buffer - if _, err := tt.WriteTo(&b); err != nil { - t.Fatal(err) - } - - f32s := make([]float32, mul(tt.Shape)) - if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { - t.Fatal(err) - } - - if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) { - t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s) - } - } - - { - tt, ok := next() - if !ok { - t.Fatal("expected at least one split") - } - - if tt.Name != "a.y" { - t.Fatal("expected name 'a.y', got", tt.Name) - } - - if !slices.Equal(tt.Shape, []uint64{3, 2}) { - t.Fatal("expected shape 
[3, 2], got", tt.Shape) - } - - var b bytes.Buffer - if _, err := tt.WriteTo(&b); err != nil { - t.Fatal(err) - } - - f32s := make([]float32, mul(tt.Shape)) - if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { - t.Fatal(err) - } - - if !slices.Equal(f32s, []float32{2, 6, 10, 3, 7, 11}) { - t.Fatal("expected data [2, 6, 10, 3, 7, 11], got", f32s) - } - } + }) }) } diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 32f459a3a..afb90720f 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -1,6 +1,7 @@ package ggml import ( + "cmp" "encoding/binary" "errors" "fmt" @@ -179,6 +180,7 @@ func (kv KV) OllamaEngineRequired() bool { "llama4", "mllama", "qwen25vl", + "gptoss", }, kv.Architecture()) } @@ -280,7 +282,7 @@ func (t Tensor) block() (n int) { } func (t Tensor) blockSize() uint64 { - return (TensorType)(t.Kind).BlockSize() + return TensorType(t.Kind).BlockSize() } func (t TensorType) BlockSize() uint64 { @@ -298,6 +300,7 @@ func (t TensorType) BlockSize() uint64 { case 2, // Q4_0 3, // Q4_1 + 4, // MXFP4 6, // Q5_0 7, // Q5_1 8, // Q8_0 @@ -325,6 +328,8 @@ func (t TensorType) TypeSize() uint64 { return 2 + blockSize/2 case TensorTypeQ4_1: return 2 + 2 + blockSize/2 + case TensorTypeMXFP4: + return 1 + blockSize/2 case TensorTypeQ5_0: return 2 + 4 + blockSize/2 case TensorTypeQ5_1: @@ -487,9 +492,11 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri layers := f.Tensors().GroupLayers() bytesPerElement := kvCacheBytesPerElement(kvCacheType) + var kvTotal uint64 kv = make([]uint64, f.KV().BlockCount()) for i := range kv { kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement) + kvTotal += kv[i] } switch f.KV().Architecture() { @@ -658,6 +665,18 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri 4*qkvBias.Shape[0], ) } + case "gptoss": + kv = make([]uint64, f.KV().BlockCount()) + for i := range kv { + kv[i] = uint64(float64((embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement) + if i%2 == 0 { + kv[i] *= (uint64(numParallel)*4096 + batch) + } else { + kv[i] *= context + } + } + fullOffload = 4 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6 + partialOffload = 2 * fullOffload } return diff --git a/fs/ggml/type.go b/fs/ggml/type.go index 4d3d5bcad..3e5deb87b 100644 --- a/fs/ggml/type.go +++ b/fs/ggml/type.go @@ -14,9 +14,9 @@ const ( FileTypeF16 fileTypeQ4_0 fileTypeQ4_1 - fileTypeQ4_1_F16 // unused by GGML - fileTypeQ4_2 // unused by GGML - fileTypeQ4_3 // unused by GGML + fileTypeMXFP4 // originally fileTypeQ4_1_F16 // unused by GGML + fileTypeQ4_2 // unused by GGML + fileTypeQ4_3 // unused by GGML FileTypeQ8_0 fileTypeQ5_0 fileTypeQ5_1 @@ -97,6 +97,8 @@ func (t FileType) String() string { return "Q4_0" case fileTypeQ4_1: return "Q4_1" + case fileTypeMXFP4: + return "MXFP4" case FileTypeQ8_0: return "Q8_0" case fileTypeQ5_0: @@ -144,6 +146,8 @@ func (ftype FileType) ToTensorType() TensorType { return TensorTypeQ4_0 case fileTypeQ4_1: return TensorTypeQ4_1 + case fileTypeMXFP4: + return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2 case FileTypeQ8_0: return TensorTypeQ8_0 case fileTypeQ5_0: @@ -187,8 +191,8 @@ const ( TensorTypeF16 TensorTypeQ4_0 TensorTypeQ4_1 - tensorTypeQ4_2 // unused by GGML - tensorTypeQ4_3 // unused by GGML + TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2 + tensorTypeQ4_3 // unused by GGML TensorTypeQ5_0 TensorTypeQ5_1 TensorTypeQ8_0 @@ -260,6 +264,8 @@ func ParseTensorType(s string) (TensorType, error) { return 
TensorTypeF64, nil case "BF16": return TensorTypeBF16, nil + case "MXFP4": + return TensorTypeMXFP4, nil default: return 0, fmt.Errorf("unsupported quantization type %s", s) } @@ -312,6 +318,8 @@ func (t TensorType) String() string { return "F64" case TensorTypeBF16: return "BF16" + case TensorTypeMXFP4: + return "MXFP4" default: return "unknown" } diff --git a/llama/patches/0019-metal-add-mean-kernel-14267.patch b/llama/patches/0019-metal-add-mean-kernel-14267.patch index f20e854b2..e65aeb7b4 100644 --- a/llama/patches/0019-metal-add-mean-kernel-14267.patch +++ b/llama/patches/0019-metal-add-mean-kernel-14267.patch @@ -19,7 +19,7 @@ diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index a9eeebc6..110c9ece 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m -@@ -489,6 +489,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte +@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_NEG, GGML_METAL_KERNEL_TYPE_SUM_ROWS, @@ -27,7 +27,7 @@ index a9eeebc6..110c9ece 100644 GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, GGML_METAL_KERNEL_TYPE_ARGMAX, -@@ -1436,6 +1437,7 @@ @implementation GGMLMetalClass +@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); diff --git a/llama/patches/0022-BF16-macos-version-guard.patch b/llama/patches/0022-BF16-macos-version-guard.patch index 68aac0bb0..88e4f7cb0 100644 --- a/llama/patches/0022-BF16-macos-version-guard.patch +++ b/llama/patches/0022-BF16-macos-version-guard.patch @@ -12,7 +12,7 @@ diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 110c9ece..ab46f6e3 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m -@@ -89,7 +89,11 @@ +@@ -89,7 +89,11 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6]; #if defined(GGML_METAL_USE_BF16) diff --git a/llama/patches/0023-MXFP4.patch b/llama/patches/0023-MXFP4.patch new file mode 100644 index 000000000..2beb1518d --- /dev/null +++ b/llama/patches/0023-MXFP4.patch @@ -0,0 +1,1293 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen +Date: Mon, 21 Jul 2025 12:06:13 -0700 +Subject: [PATCH] MXFP4 + +Partial implementation of MXFP4 tensor type +--- + ggml/include/ggml.h | 2 +- + ggml/src/ggml-common.h | 7 + + ggml/src/ggml-cpu/ggml-cpu-quants.h | 2 + + ggml/src/ggml-cpu/ggml-cpu.c | 5 + + ggml/src/ggml-cpu/ops.cpp | 1 + + ggml/src/ggml-cpu/vec.cpp | 90 ++++++++ + ggml/src/ggml-cpu/vec.h | 2 + + ggml/src/ggml-cuda/convert.cu | 80 +++++++ + ggml/src/ggml-cuda/ggml-cuda.cu | 16 +- + ggml/src/ggml-cuda/mmvmxfp4.cu | 307 ++++++++++++++++++++++++++ + ggml/src/ggml-cuda/mmvmxfp4.cuh | 9 + + ggml/src/ggml-metal/ggml-metal-impl.h | 3 + + ggml/src/ggml-metal/ggml-metal.m | 25 ++- + ggml/src/ggml-metal/ggml-metal.metal | 173 ++++++++++++++- + ggml/src/ggml-quants.c | 142 +++++++++++- + ggml/src/ggml-quants.h | 6 + + ggml/src/ggml.c | 13 +- + 17 files changed, 868 insertions(+), 15 deletions(-) + create mode 100644 ggml/src/ggml-cuda/mmvmxfp4.cu + create mode 100644 ggml/src/ggml-cuda/mmvmxfp4.cuh + +diff --git 
a/ggml/include/ggml.h b/ggml/include/ggml.h +index e91dedf1..873baa24 100644 +--- a/ggml/include/ggml.h ++++ b/ggml/include/ggml.h +@@ -353,7 +353,7 @@ extern "C" { + GGML_TYPE_F16 = 1, + GGML_TYPE_Q4_0 = 2, + GGML_TYPE_Q4_1 = 3, +- // GGML_TYPE_Q4_2 = 4, support has been removed ++ GGML_TYPE_MXFP4 = 4, // Formerly removed type GGML_TYPE_Q4_2 + // GGML_TYPE_Q4_3 = 5, support has been removed + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 7, +diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h +index 086c822d..e0d71451 100644 +--- a/ggml/src/ggml-common.h ++++ b/ggml/src/ggml-common.h +@@ -417,6 +417,13 @@ typedef struct { + } block_iq4_xs; + static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); + ++#define MXFP4 32 ++typedef struct { ++ uint8_t d; // scale E8M0 float ++ uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float ++} block_mxfp4; ++static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding"); ++ + #endif // GGML_COMMON_DECL + #endif // GGML_COMMON_DECL + +diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.h b/ggml/src/ggml-cpu/ggml-cpu-quants.h +index e33d9d47..6a25d062 100644 +--- a/ggml/src/ggml-cpu/ggml-cpu-quants.h ++++ b/ggml/src/ggml-cpu/ggml-cpu-quants.h +@@ -58,6 +58,8 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const + void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + ++void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc); ++ + #ifdef __cplusplus + } + #endif +diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c +index 2462d2b8..bff9c426 100644 +--- a/ggml/src/ggml-cpu/ggml-cpu.c ++++ b/ggml/src/ggml-cpu/ggml-cpu.c +@@ -362,6 +362,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, ++ [GGML_TYPE_MXFP4] = { ++ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_mxfp4, ++ .vec_dot_type = GGML_TYPE_F32, ++ .nrows = 1, ++ }, + }; + + const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { +diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp +index 654e2f28..be0aa683 100644 +--- a/ggml/src/ggml-cpu/ops.cpp ++++ b/ggml/src/ggml-cpu/ops.cpp +@@ -4965,6 +4965,7 @@ void ggml_compute_forward_clamp( + case GGML_TYPE_I32: + case GGML_TYPE_I64: + case GGML_TYPE_F64: ++ case GGML_TYPE_MXFP4: + case GGML_TYPE_COUNT: + { + GGML_ABORT("fatal error"); +diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp +index 02d40618..ec3ec9b1 100644 +--- a/ggml/src/ggml-cpu/vec.cpp ++++ b/ggml/src/ggml-cpu/vec.cpp +@@ -250,3 +250,93 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl + } + return sum = (ggml_float)logf(sum); + } ++ ++#define MXFP4 32 ++typedef struct { ++ uint8_t d; // scale E8M0 float ++ uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float ++} block_mxfp4; ++static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding"); ++#define MXFP4_VALS {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0} ++ 
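
For illustration (not part of this patch), a minimal Go sketch of the scalar dequantization implied by the block_mxfp4 layout and MXFP4_VALS table above, and mirroring how ggml_vec_dot_mxfp4 below reads a block: one E8M0 scale byte interpreted as a float32 exponent (so the scale is 2^(d-127)), followed by 16 bytes holding 32 packed 4-bit E2M1 values, low nibble first. Function and variable names here are illustrative, not from the patch:

package main

import (
	"fmt"
	"math"
)

// Same E2M1 value table as MXFP4_VALS above.
var mxfp4Vals = [16]float32{0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6}

// dequantMXFP4Block expands one block (scale byte d plus 16 bytes of packed
// 4-bit values) into 32 float32 values.
func dequantMXFP4Block(d uint8, qs [16]byte) [32]float32 {
	// E8M0 scale: the byte becomes the float32 exponent field, i.e. 2^(d-127).
	scale := math.Float32frombits(uint32(d) << 23)
	var out [32]float32
	for i, b := range qs {
		out[2*i] = mxfp4Vals[b&0x0f] * scale
		out[2*i+1] = mxfp4Vals[b>>4] * scale
	}
	return out
}

func main() {
	var qs [16]byte
	qs[0] = 0x21 // low nibble 0x1 -> 0.5, high nibble 0x2 -> 1.0
	out := dequantMXFP4Block(127, qs) // d = 127 encodes a scale of 1.0
	fmt.Println(out[0], out[1])       // prints: 0.5 1
}
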
++void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) { ++ assert(nrc == 1); ++ GGML_UNUSED(nrc); ++ GGML_UNUSED(bx); ++ GGML_UNUSED(by); ++ GGML_UNUSED(bs); ++ ggml_float mxfp4_table[] = MXFP4_VALS; ++ ++#if defined(GGML_SIMD) ++ float sumf = 0.0f; ++ const int np = (n & ~(GGML_F32_STEP - 1)); ++ const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx; ++ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; ++ ++ GGML_F32_VEC scalev; ++ GGML_F32_VEC ax[GGML_F32_ARR]; ++ GGML_F32_VEC ay[GGML_F32_ARR]; ++ for (int i = 0; i < np; i += GGML_F32_STEP) { // ARM: +16 AVX512: +64 ++ for (int j = 0; j < GGML_F32_ARR; j++) { // ARM: 0 .. 4 AVX512: 0 .. 4 ++ // convert GGML_F32_ARR X elements ++ const int ib = (i + j*GGML_F32_EPR) / MXFP4; ++ const block_mxfp4 * GGML_RESTRICT x = &xx[ib]; ++ union { ++ uint32_t as_bits; ++ float as_value; ++ } scale; ++ scale.as_bits = (((uint32_t)x->d) << 23); ++ scalev = GGML_F32_VEC_SET1(scale.as_value); ++ float xf[GGML_F32_EPR]= {0.f}; ++ assert(((i+j*GGML_F32_EPR) % MXFP4)+GGML_F32_ARR < MXFP4 && "block overrun"); ++ for (int qi = 0; qi < GGML_F32_EPR/2 ; ++qi) { ++ xf[qi*2] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf)]; ++ xf[qi*2+1] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf0) >> 4]; ++ } ++ ++ ax[j] = GGML_F32_VEC_MUL(GGML_F32_VEC_LOAD(xf), scalev); ++ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); ++ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); ++ } ++ } ++ GGML_F32_VEC_REDUCE(sumf, sum); ++ ++ // leftovers ++ for (int i = np; i < n; i+=2) { ++ const int ib = i / MXFP4; ++ const block_mxfp4 * GGML_RESTRICT x = &xx[ib]; ++ union { ++ uint32_t as_bits; ++ float as_value; ++ } scale; ++ scale.as_bits = (((uint32_t)x->d) << 23); ++ sumf += y[i] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf)]; ++ sumf += y[i+1] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf0) >> 4]; ++ } ++ ++ ++#else // defined(GGML_SIMD) ++ const int nb = n / MXFP4; ++ assert(n % MXFP4 == 0); ++ ++ int yi = 0; ++ ++ const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx; ++ ++ ggml_float sumf = 0.0; ++ for (int ib = 0; ib < nb; ++ib) { ++ const block_mxfp4 * GGML_RESTRICT x = &xx[ib + 0]; ++ union { ++ uint32_t as_bits; ++ float as_value; ++ } scale; ++ scale.as_bits = (((uint32_t)x->d) << 23); ++ for (int i = 0; i < MXFP4/2; ++i) { ++ sumf += mxfp4_table[(x->qs[i] & 0xf)] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2]); ++ sumf += mxfp4_table[(x->qs[i] & 0xf0) >> 4] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2+1]); ++ } ++ } ++#endif ++ ++ *s = sumf; ++} +diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h +index 23cbb305..7480ca08 100644 +--- a/ggml/src/ggml-cpu/vec.h ++++ b/ggml/src/ggml-cpu/vec.h +@@ -42,6 +42,8 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G + void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc); + void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); + ++void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc); ++ + void ggml_vec_silu_f32(const int n, float * y, const float * x); + ggml_float 
ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); + ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); +diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu +index c6dec427..0e016ccc 100644 +--- a/ggml/src/ggml-cuda/convert.cu ++++ b/ggml/src/ggml-cuda/convert.cu +@@ -571,6 +571,82 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t + dequantize_block_iq4_xs<<>>(vx, y); + } + ++// MXFP4 dequantize derived from dequantize_block_q4_0 ++template ++static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { ++ const uint16_t dst_bias = 15; ++ const uint16_t dst_0p5 = 0x3800; ++ const uint16_t dst_m_bits = 10; ++ const int64_t i = blockIdx.x; ++ ++ // assume 32 threads ++ const int64_t tid = threadIdx.x; ++ const int64_t il = tid/8; ++ const int64_t ir = tid%8; ++ const int64_t ib = 8*i + ir; ++ if (ib >= nb32) { ++ return; ++ } ++ ++ const uint64_t offset = 256*i + MXFP4*ir + 8*il; ++ dst_t * y = yy + offset; ++ ++ const block_mxfp4 * x = (const block_mxfp4 *)vx + ib; ++ union { ++ uint32_t as_bits; ++ float as_value; ++ } scale; ++ scale.as_bits = (((uint32_t)x->d) << 23); ++ ++ // offset within the block 1/4 chunks (8 items) ++ const uint8_t * q = x->qs + 4*il; ++ ++ for (int l = 0; l < 4; ++l) { ++ uint16_t em0 = q[l] & 0x07; ++ uint16_t em1 = q[l] & 0x70; ++ // float16 values ++ iq1m_scale_t x0; ++ iq1m_scale_t x1; ++ ++ x0.u16 = (em0 << (dst_m_bits - 1)) | ((q[l] & 0x08) << 12); ++ x1.u16 = (em1 << (dst_m_bits - 5)) | ((q[l] & 0x80) << 8); ++ ++ // Three cases: ++ // x is normal and non-zero: Correct bias ++ if ((em0 & 0x06) != 0) { ++ x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits); ++ } ++ if ((em1 & 0x60) != 0) { ++ x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits); ++ } ++ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type ++ if (em0 == 0x01) { ++ x0.u16 = dst_0p5 | (x0.u16 & 0x8000); ++ } ++ if (em1 == 0x10) { ++ x1.u16 = dst_0p5 | (x1.u16 & 0x8000); ++ } ++ // x is zero, do nothing ++ ++ // XXX it looks correct here - but mulmat still gives bad results... 
++ // printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n", ++ // i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 0, scale * float(x0.f16)); ++ // printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n", ++ // i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 1, scale * float(x1.f16)); ++ ++ y[l*2] = scale.as_value * float(x0.f16); ++ y[l*2+1] = scale.as_value * float(x1.f16); ++ } ++} ++ ++// derived from dequantize_row_q4_0_cuda ++template ++static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { ++ const int nb32 = k / 32; ++ const int nb = (k + 255) / 256; ++ dequantize_block_mxfp4<<>>(vx, y, nb32); ++} ++ + template + static __global__ void convert_unary( + const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, +@@ -664,6 +740,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { + return convert_unary_cont_cuda; + case GGML_TYPE_BF16: + return convert_unary_cont_cuda; ++ case GGML_TYPE_MXFP4: ++ return dequantize_row_mxfp4_cuda; + default: + return nullptr; + } +@@ -713,6 +791,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { + return convert_unary_cont_cuda; + case GGML_TYPE_BF16: + return convert_unary_cont_cuda; ++ case GGML_TYPE_MXFP4: ++ return dequantize_row_mxfp4_cuda; + default: + return nullptr; + } +diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu +index 28ccf4be..bb19b06e 100644 +--- a/ggml/src/ggml-cuda/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda/ggml-cuda.cu +@@ -21,6 +21,7 @@ + #include "ggml-cuda/im2col.cuh" + #include "ggml-cuda/mmq.cuh" + #include "ggml-cuda/mmv.cuh" ++#include "ggml-cuda/mmvmxfp4.cuh" + #include "ggml-cuda/mmvq.cuh" + #include "ggml-cuda/norm.cuh" + #include "ggml-cuda/opt-step-adamw.cuh" +@@ -1202,7 +1203,7 @@ static void ggml_cuda_op_mul_mat_cublas( + + const int cc = ggml_cuda_info().devices[id].cc; + +- const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; ++ const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT && src0->type != GGML_TYPE_MXFP4; + + if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { + ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); +@@ -1924,7 +1925,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor + && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 +- && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; ++ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE ++ && src0->type != GGML_TYPE_MXFP4; ++ bool use_mul_mat_vec_mxfp4 = src0->type == GGML_TYPE_MXFP4 ++ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 ++ && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; + bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + +@@ -1978,6 +1983,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); + } else if (use_mul_mat_q) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); ++ } else if 
(use_mul_mat_vec_mxfp4) { ++ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_mxfp4, nullptr); + } else { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); + } +@@ -1997,6 +2004,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ if (ne2 == 1 && src0->type == GGML_TYPE_MXFP4) { ++ ggml_cuda_mul_mat_vec_mxfp4(ctx, src0, src1, ids, dst); ++ return; ++ } + if (ne2 == 1) { + if (ggml_is_quantized(src0->type)) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); +@@ -3056,6 +3067,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_BF16: ++ case GGML_TYPE_MXFP4: + #ifdef GGML_USE_MUSA + if (a->type == GGML_TYPE_Q3_K) { + return false; +diff --git a/ggml/src/ggml-cuda/mmvmxfp4.cu b/ggml/src/ggml-cuda/mmvmxfp4.cu +new file mode 100644 +index 00000000..da62062b +--- /dev/null ++++ b/ggml/src/ggml-cuda/mmvmxfp4.cu +@@ -0,0 +1,307 @@ ++#include "ggml.h" ++#include "common.cuh" ++#include "mmvmxfp4.cuh" ++ ++// MXFP4 implementation derived from mmv.cu float32 code paths ++typedef union { ++ half f16; ++ uint16_t u16; ++} f16_t; ++ ++template // TODO type_acc unused - consider bf16 support ++static __global__ void mul_mat_vec_mxfp4( ++ const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, ++ const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row, ++ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, ++ const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) { ++ const int64_t row = blockIdx.x; ++ const int64_t channel_dst = blockIdx.y; ++ const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; ++ const int64_t channel_y = ids ? 
channel_dst % nchannels_y : channel_dst; ++ const int64_t sample_dst = blockIdx.z; ++ const int64_t sample_x = sample_dst / sample_ratio; ++ const int64_t sample_y = sample_dst; ++ const int tid = threadIdx.x; ++ constexpr int warp_size = ggml_cuda_get_physical_warp_size(); ++ ++ const uint16_t dst_bias = 15; ++ const uint16_t dst_0p5 = 0x3800; ++ const uint16_t dst_m_bits = 10; ++ ++ x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row; ++ y += sample_y *stride_sample_y + channel_y *stride_channel_y; ++ dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst; ++ ++ const float2 * y2 = (const float2 *) y; ++ ++ extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float) ++ float * buf_iw = (float *) data_mmv; ++ ++ if (block_size > warp_size) { ++ if (tid < warp_size) { ++ buf_iw[tid] = 0.0f; ++ } ++ __syncthreads(); ++ } ++ ++ float sumf = 0.0f; ++ ++ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { ++ int offset0 = col2 / (MXFP4/2); ++ int i = col2 % (MXFP4/2); ++ const block_mxfp4 *x2 = x+offset0; ++ ++ union { ++ uint32_t as_bits; ++ float as_value; ++ } scale; ++ scale.as_bits = (((uint32_t)x2->d) << 23); ++ uint16_t em0 = x2->qs[i] & 0x07; ++ uint16_t em1 = x2->qs[i] & 0x70; ++ // float16 values ++ f16_t x0; ++ f16_t x1; ++ x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12); ++ x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8); ++ ++ // Three cases: ++ // x is normal and non-zero: Correct bias ++ if ((em0 & 0x06) != 0) { ++ x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits); ++ } ++ if ((em1 & 0x60) != 0) { ++ x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits); ++ } ++ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type ++ if (em0 == 0x01) { ++ x0.u16 = dst_0p5 | (x0.u16 & 0x8000); ++ } ++ if (em1 == 0x10) { ++ x1.u16 = dst_0p5 | (x1.u16 & 0x8000); ++ } ++ // x is zero, do nothing ++ ++ if (isnan(scale.as_value)) { ++ sumf = scale.as_value; ++ break; ++ } ++ ++ const float2 tmpx = {x0.f16, x1.f16}; ++ const float2 tmpy = y2[col2]; ++ sumf += tmpx.x*tmpy.x*scale.as_value; ++ sumf += tmpx.y*tmpy.y*scale.as_value; ++ } ++ ++ sumf = warp_reduce_sum(sumf); ++ ++ if (block_size > warp_size) { ++ buf_iw[tid/warp_size] = sumf; ++ __syncthreads(); ++ if (tid >= warp_size) { ++ return; ++ } ++ sumf = buf_iw[tid]; ++ sumf = warp_reduce_sum(sumf); ++ } ++ ++ if (tid != 0) { ++ return; ++ } ++ ++ dst[row] = sumf; ++} ++ ++template ++static void launch_mul_mat_vec_cuda_mxfp4( ++ const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst, ++ const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, ++ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, ++ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, ++ cudaStream_t stream) { ++ GGML_ASSERT(ncols % 2 == 0); ++ // GGML_ASSERT(stride_row % 2 == 0); // TODO ++ GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); ++ GGML_ASSERT( nsamples_dst % nsamples_x == 0); ++ const int64_t channel_ratio = nchannels_dst / nchannels_x; ++ const int64_t sample_ratio = nsamples_dst / nsamples_x; ++ int device; ++ int warp_size; ++ ++ CUDA_CHECK(cudaGetDevice(&device)); ++ warp_size = ggml_cuda_info().devices[device].warp_size; ++ ++ int64_t block_size_best = warp_size; ++ int64_t 
niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); ++ int64_t max_block_size = 256; ++ if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) { ++ max_block_size = 128; ++ } ++ for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) { ++ const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); ++ if (niter < niter_best) { ++ niter_best = niter; ++ block_size_best = block_size; ++ } ++ } ++ ++ const int smem = warp_size*sizeof(float); ++ const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); ++ const dim3 block_dims(block_size_best, 1, 1); ++ ++ switch (block_size_best) { ++ case 32: { ++ mul_mat_vec_mxfp4<<>> ++ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, ++ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); ++ } break; ++ case 64: { ++ mul_mat_vec_mxfp4<<>> ++ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, ++ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); ++ } break; ++ case 96: { ++ mul_mat_vec_mxfp4<<>> ++ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, ++ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); ++ } break; ++ case 128: { ++ mul_mat_vec_mxfp4<<>> ++ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, ++ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); ++ } break; ++ case 160: { ++ mul_mat_vec_mxfp4<<>> ++ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, ++ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); ++ } break; ++ case 192: { ++ mul_mat_vec_mxfp4<<>> ++ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, ++ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); ++ } break; ++ case 224: { ++ mul_mat_vec_mxfp4<<>> ++ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, ++ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); ++ } break; ++ case 256: { ++ mul_mat_vec_mxfp4<<>> ++ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, ++ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); ++ } break; ++ default: { ++ GGML_ABORT("fatal error"); ++ } break; ++ } ++} ++ ++static void mul_mat_vec_cuda_mxfp4( ++ const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst, ++ const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, ++ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, ++ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, ++ enum ggml_prec prec, cudaStream_t stream) { ++ launch_mul_mat_vec_cuda_mxfp4 ++ (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, ++ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, 
stride_sample_y, stride_sample_dst, stream); ++} ++ ++void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { ++ GGML_ASSERT( src1->type == GGML_TYPE_F32); ++ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); ++ GGML_ASSERT( dst->type == GGML_TYPE_F32); ++ ++ GGML_TENSOR_BINARY_OP_LOCALS; ++ ++ const size_t ts_src0 = ggml_type_size(src0->type); ++ const size_t ts_src1 = ggml_type_size(src1->type); ++ const size_t ts_dst = ggml_type_size(dst->type); ++ ++ GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. ++ GGML_ASSERT(ne13 == ne3); ++ ++ // GGML_ASSERT( nb00 == ts_src0); // TODO adjust for block sizing logic ++ GGML_ASSERT( nb10 == ts_src1); ++ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); ++ GGML_ASSERT( nb0 == ts_dst); ++ ++ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; ++ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; ++ ++ const float * src1_d = (const float *) src1->data; ++ const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; ++ float * dst_d = (float *) dst->data; ++ ++ const int64_t stride_row = src0->nb[1] / ts_src0; ++ const int64_t s11 = src1->nb[1] / ts_src1; ++ const int64_t s1 = dst->nb[1] / ts_dst; ++ const int64_t stride_channel_x = src0->nb[2] / ts_src0; ++ const int64_t s12 = src1->nb[2] / ts_src1; ++ const int64_t s2 = dst->nb[2] / ts_dst; ++ const int64_t stride_sample_x = src0->nb[3] / ts_src0; ++ const int64_t stride_sample_y = src1->nb[3] / ts_src1; ++ const int64_t stride_sample_dst = dst->nb[3] / ts_dst; ++ const int64_t nsamples_dst = ne3; ++ const int64_t nsamples_x = ne03; ++ const int64_t nchannels_x = ne02; ++ const int64_t nrows = ne01; ++ const int64_t ncols = ne00; ++ ++ // For MUL_MAT_ID the memory layout is different than for MUL_MAT: ++ const int64_t ncols_dst = ids ? ne2 : ne1; ++ const int64_t nchannels_y = ids ? ne11 : ne12; ++ const int64_t nchannels_dst = ids ? ne1 : ne2; ++ const int64_t stride_channel_dst = ids ? s1 : s2; ++ const int64_t stride_channel_y = ids ? s11 : s12; ++ ++ GGML_ASSERT(ncols_dst == 1); ++ ++ const block_mxfp4 * src0_d = (const block_mxfp4 *) src0->data; ++ mul_mat_vec_cuda_mxfp4(src0_d, src1_d, ids_d, dst_d, ncols, nrows, stride_row, ++ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, ++ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, ctx.stream()); ++} ++ ++void ggml_cuda_op_mul_mat_vec_mxfp4( ++ ggml_backend_cuda_context & ctx, ++ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, ++ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, ++ const int64_t src1_padded_row_size, cudaStream_t stream) { ++ ++ GGML_ASSERT(src1->type == GGML_TYPE_F32); ++ GGML_ASSERT(dst->type == GGML_TYPE_F32); ++ ++ const int64_t ne00 = src0->ne[0]; ++ const int64_t row_diff = row_high - row_low; ++ ++ GGML_ASSERT(src1_ncols == 1); ++ ++ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; ++ const enum ggml_prec prec = fast_fp16_available(cc) ? 
ggml_prec(dst->op_params[0]) : GGML_PREC_F32; ++ ++ // ggml_cuda_op provides single, contiguous matrices ++ const int64_t stride_row = ne00 / MXFP4; ++ const int64_t nchannels_x = 1; ++ const int64_t nchannels_y = 1; ++ const int64_t nchannels_dst = 1; ++ const int64_t stride_channel_x = 0; ++ const int64_t stride_channel_y = 0; ++ const int64_t stride_channel_dst = 0; ++ const int64_t nsamples_x = 1; ++ const int64_t nsamples_dst = 1; ++ const int64_t stride_sample_x = 0; ++ const int64_t stride_sample_y = 0; ++ const int64_t stride_sample_dst = 0; ++ ++ const block_mxfp4 * src0_d = (const block_mxfp4 *) src0_dd_i; ++ mul_mat_vec_cuda_mxfp4(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, ++ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, ++ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); ++ ++ GGML_UNUSED(ctx); ++ GGML_UNUSED(src1); ++ GGML_UNUSED(dst); ++ GGML_UNUSED(src1_ddq_i); ++ GGML_UNUSED(src1_ncols); ++ GGML_UNUSED(src1_padded_row_size); ++} +diff --git a/ggml/src/ggml-cuda/mmvmxfp4.cuh b/ggml/src/ggml-cuda/mmvmxfp4.cuh +new file mode 100644 +index 00000000..a08fc780 +--- /dev/null ++++ b/ggml/src/ggml-cuda/mmvmxfp4.cuh +@@ -0,0 +1,9 @@ ++#include "common.cuh" ++ ++void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); ++ ++void ggml_cuda_op_mul_mat_vec_mxfp4( ++ ggml_backend_cuda_context & ctx, ++ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, ++ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, ++ const int64_t src1_padded_row_size, cudaStream_t stream); +diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h +index 17eab976..938386ba 100644 +--- a/ggml/src/ggml-metal/ggml-metal-impl.h ++++ b/ggml/src/ggml-metal/ggml-metal-impl.h +@@ -65,6 +65,9 @@ + #define N_R0_IQ4_XS 2 + #define N_SG_IQ4_XS 2 + ++#define N_R0_MXFP4 4 ++#define N_SG_MXFP4 2 ++ + // kernel argument structs + // + // - element counters (e.g. 
ne00) typically use int32_t to reduce register usage +diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m +index ab46f6e3..d8e05a21 100644 +--- a/ggml/src/ggml-metal/ggml-metal.m ++++ b/ggml/src/ggml-metal/ggml-metal.m +@@ -40,6 +40,7 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; + static struct ggml_backend_reg g_ggml_backend_metal_reg; + static struct ggml_backend_device g_ggml_backend_metal_device; + ++ + // information about a Metal device + // note: assumes single GPU device - the default one + // TODO: support multiple GPU devices +@@ -209,6 +210,7 @@ enum ggml_metal_kernel_type { + GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, ++ GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, +@@ -288,6 +290,7 @@ enum ggml_metal_kernel_type { + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, ++ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, +@@ -310,6 +313,7 @@ enum ggml_metal_kernel_type { + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, ++ GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, +@@ -334,6 +338,7 @@ enum ggml_metal_kernel_type { + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, ++ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, +@@ -934,7 +939,7 @@ static id ggml_metal_load_library(id device, bool use_bfl + + MTLCompileOptions * options = [MTLCompileOptions new]; + options.preprocessorMacros = prep; +- ++ + //[options setFastMathEnabled:false]; + + metal_library = [device newLibraryWithSource:src options:options error:&error]; +@@ -1157,6 +1162,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); +@@ -1236,6 +1242,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, 
has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); +@@ -1258,6 +1265,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); +@@ -1282,6 +1290,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); +@@ -3007,6 +3016,7 @@ static bool ggml_metal_encode_node( + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; ++ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break; + default: GGML_ABORT("MUL MAT-MAT not implemented"); + } + +@@ -3212,6 +3222,12 @@ static bool ggml_metal_encode_node( + smem = 32*sizeof(float); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; + } break; ++ case GGML_TYPE_MXFP4: ++ { ++ nsg = N_SG_MXFP4; ++ nr0 = N_R0_MXFP4; ++ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline; ++ } break; + default: + { + GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); +@@ -3396,6 +3412,7 @@ static bool ggml_metal_encode_node( + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; ++ 
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + +@@ -3607,6 +3624,12 @@ static bool ggml_metal_encode_node( + smem = 32*sizeof(float); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; + } break; ++ case GGML_TYPE_MXFP4: ++ { ++ nsg = N_SG_MXFP4; ++ nr0 = N_R0_MXFP4; ++ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline; ++ } break; + default: + { + GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t); +diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal +index 08e8d807..69fa17de 100644 +--- a/ggml/src/ggml-metal/ggml-metal.metal ++++ b/ggml/src/ggml-metal/ggml-metal.metal +@@ -1902,16 +1902,16 @@ void mul_vec_q_n_f32_impl( + device const char * src1, + device char * dst, + threadgroup char * shmem, +- uint3 tgpig, +- ushort tiisg, +- ushort sgitg) { +- const int nb = args.ne00/QK4_0; ++ uint3 tgpig, // Threadgroup Position in Grid ++ ushort tiisg, // Thread Index in SIMD Group ++ ushort sgitg) { // SIMD Group Index in ThreadGroup ++ const int nb = args.ne00/QK4_0; // src0->ne[0] / 32 + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + +- const int first_row = (r0 * nsg + sgitg) * nr0; ++ const int first_row = (r0 * nsg + sgitg) * nr0; // nsg=2 nr0=4 + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; +@@ -6744,6 +6744,49 @@ kernel void kernel_mul_mm_id( + } + } + ++template ++void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) { ++ float4x4 reg_f; ++ const ushort dst_bias = 15; ++ const ushort dst_0p5 = 0x3800; ++ const ushort dst_m_bits = 10; ++ const half scale = (half)(as_type(((uint32_t)xb->d) << 23)); ++ // il:0 first 16, il:1 last 16 ++ for (int i = 0; i < 8; i++) { ++ ushort em0 = xb->qs[il*8 + i] & 0x07; ++ ushort em1 = xb->qs[il*8 + i] & 0x70; ++ // float16 values ++ ushort x0 = (em0 << (dst_m_bits - 1)) | ((xb->qs[il*8 + i] & 0x08) << 12); ++ ushort x1 = (em1 << (dst_m_bits - 5)) | ((xb->qs[il*8 + i] & 0x80) << 8); ++ ++ // Three cases: ++ // x is normal and non-zero: Correct bias ++ if ((em0 & 0x06) != 0) { ++ x0 = x0 + ((dst_bias - 1) << dst_m_bits); ++ } ++ if ((em1 & 0x60) != 0) { ++ x1 = x1 + ((dst_bias - 1) << dst_m_bits); ++ } ++ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type ++ if (em0 == 0x01) { ++ x0 = dst_0p5 | (x0 & 0x8000); ++ } ++ if (em1 == 0x10) { ++ x1 = dst_0p5 | (x1 & 0x8000); ++ } ++ // x is zero, do nothing ++ ++ if (isnan(scale)) { ++ reg_f[i/2][2*(i%2) + 0] = scale; ++ reg_f[i/2][2*(i%2) + 1] = scale; ++ } else { ++ reg_f[i/2][2*(i%2) + 0] = scale * as_type(x0); ++ reg_f[i/2][2*(i%2) + 1] = scale * as_type(x1); ++ } ++ } ++ reg = (type4x4) reg_f; ++} ++ + #define QK_NL 16 + + // +@@ -6811,6 +6854,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m + template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; + template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; + ++template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; ++ + // + // indirect matrix-matrix multiplication + // +@@ -6842,6 +6887,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m + template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; + template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel 
mul_mm_id kernel_mul_mm_id; + ++template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; ++ + + // + // matrix-vector multiplication +@@ -6958,6 +7005,120 @@ kernel void kernel_mul_mv_id( + sgitg); + } + ++// MXFP32 implementation derived from mul_vec_q_n_f32_impl and block_q_n_dot_y ++void mul_mv_mxfp4_f32_impl( ++ ggml_metal_kargs_mul_mv args, ++ device const char * src0, ++ device const char * src1, ++ device char * dst, ++ threadgroup char * shmem, ++ uint3 tgpig, ++ ushort tiisg, ++ ushort sgitg) { ++ const ushort dst_bias = 15; ++ const ushort dst_0p5 = 0x3800; ++ const ushort dst_m_bits = 10; ++ const int nr0 = N_R0_MXFP4; ++ const int nsg = N_SG_MXFP4; ++ const int nw = N_SIMDWIDTH; ++ const int nb = args.ne00/MXFP4; ++ ++ const int r0 = tgpig.x; ++ const int r1 = tgpig.y; ++ const int im = tgpig.z; ++ ++ const int first_row = (r0 * nsg + sgitg) * nr0; ++ ++ const uint i12 = im%args.ne12; ++ const uint i13 = im/args.ne12; ++ ++ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; ++ ++ device const float * y = (device const float *) (src1 + offset1); ++ ++ // pointers to src0 rows ++ device const block_mxfp4 * ax[nr0]; ++ for (int row = 0; row < nr0; ++row) { ++ const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ++ ++ ax[row] = (device const block_mxfp4 *) ((device char *) src0 + offset0); ++ } ++ ++ float yl[16]; // src1 vector cache ++ float sumf[nr0] = {0.f}; ++ ++ const short ix = (tiisg/2); ++ const short il = (tiisg%2)*16; ++ ++ device const float * yb = y + ix*MXFP4 + il; ++ ++ // each thread in a SIMD group deals with half a block. ++ for (int ib = ix; ib < nb; ib += nw/2) { ++ ++#pragma unroll ++ for (short row = 0; row < nr0; row++) { ++ // Processes 16 items ++ device const block_mxfp4 * qb_curr = ax[row] + ib; ++ float d = as_type(((uint32_t)(ax[row] + ib)->d) << 23); ++ // il = 0 or 16 ++ device const uint8_t *qs = ((device const uint8_t *) qb_curr + 1 + il/2); ++ for (int i = 0; i < 8; ++i) { ++ ushort em0 = qs[i] & 0x07; ++ ushort em1 = qs[i] & 0x70; ++ ushort x0 = (em0 << (dst_m_bits - 1)) | ((qs[i] & 0x08) << 12); ++ ushort x1 = (em1 << (dst_m_bits - 5)) | ((qs[i] & 0x80) << 8); ++ // Three cases: ++ // x is normal and non-zero: Correct bias ++ if ((em0 & 0x06) != 0) { ++ x0 = x0 + ((dst_bias - 1) << dst_m_bits); ++ } ++ if ((em1 & 0x60) != 0) { ++ x1 = x1 + ((dst_bias - 1) << dst_m_bits); ++ } ++ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type ++ if (em0 == 0x01) { ++ x0 = dst_0p5 | (x0 & 0x8000); ++ } ++ if (em1 == 0x10) { ++ x1 = dst_0p5 | (x1 & 0x8000); ++ } ++ // x is zero, do nothing ++ if (!isnan(d)) { ++ sumf[row] += yb[i*2] * as_type(x0) * d ++ + yb[i*2+1] * as_type(x1) * d; ++ } else { ++ sumf[row] = d; ++ } ++ } ++ } ++ ++ yb += MXFP4 * 16; ++ } ++ ++ device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; ++ ++ for (int row = 0; row < nr0; ++row) { ++ const float tot = simd_sum(sumf[row]); ++ ++ if (tiisg == 0 && first_row + row < args.ne01) { ++ dst_f32[first_row + row] = tot; ++ } ++ } ++} ++ ++[[host_name("kernel_mul_mv_mxfp4_f32")]] ++kernel void kernel_mul_mv_mxfp4_f32( ++ constant ggml_metal_kargs_mul_mv & args, ++ device const char * src0, ++ device const char * src1, ++ device char * dst, ++ threadgroup char * shmem [[threadgroup(0)]], ++ uint3 tgpig[[threadgroup_position_in_grid]], ++ ushort tiisg[[thread_index_in_simdgroup]], ++ ushort 
sgitg[[simdgroup_index_in_threadgroup]]) { ++ mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); ++} ++ + typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; + + template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +@@ -6987,6 +7148,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t + template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + ++template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; ++ + kernel void kernel_pool_2d_max_f32( + device const float * src0, + device float * dst, +diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c +index 84ec6dfe..17c308aa 100644 +--- a/ggml/src/ggml-quants.c ++++ b/ggml/src/ggml-quants.c +@@ -4925,6 +4925,144 @@ void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RE + quantize_iq2_s(x, y, 1, k, NULL); + } + ++// =============================== mxfp4 (de)-quantization ++ ++void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { ++ static const int qk = MXFP4; ++ static const uint32_t E8_BIAS = 127; ++ static const uint32_t E2_BIAS = 1; ++ ++ assert(k % qk == 0); ++ ++ const int nb = k / qk; ++ ++ for (int i = 0; i < nb; i++) { ++ float amax = 0.0f; // absolute max ++ ++ for (int j = 0; j < qk; j++) { ++ const float v = x[i*qk + j]; ++ if (amax < fabsf(v)) { ++ amax = fabsf(v); ++ } ++ } ++ ++ const float dequant_scale = amax / 6.0f; ++ uint32_t dequant_scale_exponent = 0; ++ memcpy(&dequant_scale_exponent, &dequant_scale, sizeof(dequant_scale_exponent)); ++ ++ // Rounding up ++ dequant_scale_exponent = (dequant_scale_exponent + 0x007FFFFF) & 0x7F800000; ++ // Rounding down ++ // dequant_scale_exponent = dequant_scale_exponent & 0x7F800000; ++ ++ float dequant_scale_rounded = 0.0f; ++ memcpy(&dequant_scale_rounded, &dequant_scale_exponent, sizeof(dequant_scale_rounded)); ++ float quant_scale = 0.0f; ++ if (dequant_scale_rounded != 0.0f) { ++ quant_scale = 1.0f / dequant_scale_rounded; ++ } ++ ++ y[i].d = (uint8_t)(dequant_scale_exponent >> 23); ++ ++ for (int j = 0; j < qk/2; ++j) { ++ const float x0 = x[i*qk + j*2]*quant_scale; ++ const float x1 = x[i*qk + j*2+1]*quant_scale; ++ ++ uint32_t xi0 = 0; ++ uint32_t xi1 = 0; ++ memcpy(&xi0, &x0, sizeof(xi0)); ++ memcpy(&xi1, &x1, sizeof(xi1)); ++ ++ uint32_t s0 = xi0 & 0x80000000; ++ uint32_t s1 = xi1 & 0x80000000; ++ uint32_t e0 = (xi0 >> 23) & 0xFF; ++ uint32_t e1 = (xi1 >> 23) & 0xFF; ++ uint32_t m0 = (xi0 & 0x7FFFFF); ++ uint32_t m1 = (xi1 & 0x7FFFFF); ++ ++ // 0.25 <= x < 0.75 maps to 0.5, a denormal number ++ // Move implicit bit 1 at the beginning to mantissa for denormals ++ // adjusted_exponents ++ uint32_t ae0 = E8_BIAS - (e0 + 1); ++ uint32_t ae1 = E8_BIAS - (e1 + 1); ++ if (e0 < E8_BIAS) { ++ m0 = (0x400000 | (m0 >> 1)) >> ae0; ++ } ++ if (e1 < E8_BIAS) { ++ m1 = (0x400000 | (m1 >> 1)) >> ae1; ++ } ++ ++ // For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0. 
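            // Annotation (not part of the patch): a worked example of the nibble
            // encoding performed below, assuming quant_scale == 1.0f.
            //   x0 = 1.5f  -> s0=0, e0=127, m0=0x400000
            //                 rebias: e0 = 127-126 = 1; m0>>21 = 2
            //                 ((1<<2)|2)+1 = 7, >>1 = 3 -> nibble 0x3 (E2M1 value 1.5)
            //   x0 = 0.5f  -> e0=126 (< E8_BIAS), denormal path: m0 = 0x400000 >> 0; e0 -> 0
            //                 ((0<<2)|2)+1 = 3, >>1 = 1 -> nibble 0x1 (E2M1 value 0.5)
            //   x0 = 1.25f -> halfway between 1.0 and 1.5; the +1 tie-break rounds up to nibble 0x3.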
++ e0 = MAX(e0, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS); ++ e1 = MAX(e1, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS); ++ ++ // Combine sign, exponent, and mantissa, while saturating ++ // rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right ++ uint32_t tmp0 = MIN((((e0 << 2) | (m0 >> 21)) + 1) >> 1, 0x7); ++ uint32_t tmp1 = MIN((((e1 << 2) | (m1 >> 21)) + 1) >> 1, 0x7); ++ uint8_t v0 = (uint8_t)((s0 >> 28) | tmp0); ++ uint8_t v1 = (uint8_t)((s1 >> 28) | tmp1); ++ y[i].qs[j] = v0; ++ y[i].qs[j] |= v1 << 4; ++ } ++ } ++} ++ ++void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { ++ assert(k % MXFP4 == 0); ++ ++ const int nb = k / MXFP4; ++ const uint16_t dst_bias = 15; ++ const uint16_t dst_0p5 = 0x3800; ++ const uint16_t dst_m_bits = 10; ++ ++ for (int i = 0; i < nb; i++) { ++ union { ++ uint32_t as_bits; ++ float as_value; ++ } scale; ++ scale.as_bits = (((uint32_t)x[i].d) << 23); ++ for (int j = 0; j < MXFP4/2; ++j) { ++ uint16_t em0 = x[i].qs[j] & 0x07; ++ uint16_t em1 = x[i].qs[j] & 0x70; ++ // float16 values ++ uint16_t x0 = (em0 << (dst_m_bits - 1)) | ((x[i].qs[j] & 0x08) << 12); ++ uint16_t x1 = (em1 << (dst_m_bits - 5)) | ((x[i].qs[j] & 0x80) << 8); ++ ++ // Three cases: ++ // x is normal and non-zero: Correct bias ++ if ((em0 & 0x06) != 0) { ++ x0 = x0 + ((dst_bias - 1) << dst_m_bits); ++ } ++ if ((em1 & 0x60) != 0) { ++ x1 = x1 + ((dst_bias - 1) << dst_m_bits); ++ } ++ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type ++ if (em0 == 0x01) { ++ x0 = dst_0p5 | (x0 & 0x8000); ++ } ++ if (em1 == 0x10) { ++ x1 = dst_0p5 | (x1 & 0x8000); ++ } ++ // x is zero, do nothing ++ ++ if (isnan(scale.as_value)) { ++ y[i*MXFP4 + j*2] = scale.as_value; ++ y[i*MXFP4 + j*2+1] = scale.as_value; ++ } else { ++ y[i*MXFP4 + j*2] = GGML_FP16_TO_FP32(x0)*scale.as_value; ++ y[i*MXFP4 + j*2+1] = GGML_FP16_TO_FP32(x1)*scale.as_value; ++ } ++ } ++ } ++} ++ ++ ++size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { ++ quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row); ++ return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row); ++} ++ + // =============================== data validation + + static bool validate_float(float f, size_t i) { +@@ -5214,7 +5352,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); + } break; +- ++ case GGML_TYPE_MXFP4: ++ // TODO - anything to validate? 
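            // Annotation (not part of the patch): if row validation is ever added here,
            // one plausible check is the scale byte: under the OCP MX convention the E8M0
            // value 0xFF encodes NaN, and decoding it as (d << 23) produces a non-finite
            // float32 scale, so flagging d == 0xFF would mirror the NaN/Inf checks done
            // for the other quantized types.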
++ break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: +diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h +index d09173e1..2fc40f75 100644 +--- a/ggml/src/ggml-quants.h ++++ b/ggml/src/ggml-quants.h +@@ -37,6 +37,8 @@ GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_ + GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); + GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); + ++GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); ++ + // Dequantization + GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +@@ -65,6 +67,8 @@ GGML_API void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, floa + GGML_API void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + GGML_API void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + ++GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); ++ + // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") + GGML_API size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + GGML_API size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR + GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + ++GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); ++ + GGML_API void iq2xs_init_impl(enum ggml_type type); + GGML_API void iq2xs_free_impl(enum ggml_type type); + GGML_API void iq3xs_init_impl(int grid_size); +diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c +index 8a654624..0f3c9834 100644 +--- a/ggml/src/ggml.c ++++ b/ggml/src/ggml.c +@@ -589,11 +589,13 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { + .to_float = (ggml_to_float_t) dequantize_row_q4_1, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, + }, +- [4] = { // GGML_TYPE_Q4_2 +- .type_name = "DEPRECATED", +- .blck_size = 0, +- .type_size = 0, +- .is_quantized = false, ++ [GGML_TYPE_MXFP4] = { // formerly deprecated GGML_TYPE_Q4_2 ++ .type_name = "mxfp4", ++ .blck_size = MXFP4, ++ .type_size = sizeof(block_mxfp4), ++ .is_quantized = true, ++ .to_float = (ggml_to_float_t) dequantize_row_mxfp4, ++ .from_float_ref = (ggml_from_float_t) quantize_row_mxfp4_ref, + }, + [5] = { // GGML_TYPE_Q4_3 + .type_name = "DEPRECATED", +@@ -6446,6 +6448,7 @@ size_t ggml_quantize_chunk( + case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * 
row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; ++ case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_F16: + { + size_t elemsize = sizeof(ggml_fp16_t); diff --git a/llama/patches/0024-cuda-disable-graph-compat-check-for-OP_ADD.patch b/llama/patches/0024-cuda-disable-graph-compat-check-for-OP_ADD.patch new file mode 100644 index 000000000..535b09eb1 --- /dev/null +++ b/llama/patches/0024-cuda-disable-graph-compat-check-for-OP_ADD.patch @@ -0,0 +1,34 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Michael Yang +Date: Thu, 31 Jul 2025 12:31:58 -0700 +Subject: [PATCH] cuda: disable graph compat check for OP_ADD + +--- + ggml/src/ggml-cuda/ggml-cuda.cu | 14 -------------- + 1 file changed, 14 deletions(-) + +diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu +index bb19b06e..080e7467 100644 +--- a/ggml/src/ggml-cuda/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda/ggml-cuda.cu +@@ -2509,20 +2509,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud + #endif + } + +- // workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n +- // number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here +- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256 +- && node->ne[2] == 1 +- && node->ne[3] == 1 +- && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false +- && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) { +- // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. 
+- use_cuda_graph = false; +-#ifndef NDEBUG +- GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); +-#endif +- } +- + if (node->op == GGML_OP_CPY) { + + // Store the pointers which are updated for each token, such that these can be sent diff --git a/llama/patches/0025-Disable-ggml-blas-on-macos-v13-and-older.patch b/llama/patches/0025-Disable-ggml-blas-on-macos-v13-and-older.patch new file mode 100644 index 000000000..465792600 --- /dev/null +++ b/llama/patches/0025-Disable-ggml-blas-on-macos-v13-and-older.patch @@ -0,0 +1,25 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen +Date: Sun, 3 Aug 2025 10:00:20 -0700 +Subject: [PATCH] Disable ggml-blas on macos v13 and older + +--- + ggml/src/ggml-blas/ggml-blas.cpp | 5 +++++ + 1 file changed, 5 insertions(+) + +diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp +index ec158dfa..22926d75 100644 +--- a/ggml/src/ggml-blas/ggml-blas.cpp ++++ b/ggml/src/ggml-blas/ggml-blas.cpp +@@ -505,6 +505,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = { + }; + + ggml_backend_reg_t ggml_backend_blas_reg(void) { ++ // MacOS prior to v14 does not include cblas_sgemm - disable this backend if it isn't available ++ if (&cblas_sgemm == NULL) { ++ GGML_LOG_INFO("Disabling ggml-blas backend on old MacOS version\n"); ++ return NULL; ++ } + static struct ggml_backend_reg ggml_backend_blas_reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_blas_reg_i, diff --git a/ml/backend.go b/ml/backend.go index 06f9de9ae..fcb7db5ed 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -276,6 +276,7 @@ type Tensor interface { Cos(ctx Context) Tensor Tanh(ctx Context) Tensor GELU(ctx Context) Tensor + QuickGELU(ctx Context) Tensor SILU(ctx Context) Tensor RELU(ctx Context) Tensor Sigmoid(ctx Context) Tensor @@ -283,7 +284,7 @@ type Tensor interface { Reshape(ctx Context, shape ...int) Tensor View(ctx Context, offset int, shape ...int) Tensor Permute(ctx Context, shape ...int) Tensor - Contiguous(ctx Context) Tensor + Contiguous(ctx Context, shape ...int) Tensor Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor Pad(ctx Context, shape ...int) Tensor @@ -468,4 +469,5 @@ const ( DTypeQ80 DTypeQ40 DTypeI32 + DTypeMXFP4 ) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 243476891..15c210dc1 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -239,10 +239,12 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type, layer int) *C.struct_ggml_tensor { for _, bt := range bts { if _, ok := ctxs[bt]; !ok { + // slog.Info("XXX before ggml_init") ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{ mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors), no_alloc: true, }) + // slog.Info("XXX after ggml_init") } targets[t.source.Name] = append(targets[t.source.Name], t.target) @@ -541,6 +543,8 @@ func (b *Backend) NewContextSize(n int) ml.Context { var allocatedBuffers []*C.struct_ggml_backend_buffer + // slog.Info("XXX before ggml_init") + // defer slog.Info("XXX after ggml_init") return &Context{ b: b, maxGraphNodes: n, @@ -708,6 +712,8 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { cdtype = C.GGML_TYPE_Q4_0 case ml.DTypeI32: cdtype = C.GGML_TYPE_I32 + case ml.DTypeMXFP4: + cdtype = 
C.GGML_TYPE_MXFP4 default: panic("unsupported dtype") } @@ -896,6 +902,8 @@ func (t *Tensor) DType() ml.DType { return ml.DTypeQ40 case C.GGML_TYPE_I32: return ml.DTypeI32 + case C.GGML_TYPE_MXFP4: + return ml.DTypeMXFP4 default: return ml.DTypeOther } @@ -958,10 +966,35 @@ func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor { } } -func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor { - return &Tensor{ - b: t.b, - t: C.ggml_cont(ctx.(*Context).ctx, t.t), +func (t *Tensor) Contiguous(ctx ml.Context, shape ...int) ml.Tensor { + switch len(shape) { + case 0: + return &Tensor{ + b: t.b, + t: C.ggml_cont(ctx.(*Context).ctx, t.t), + } + case 1: + return &Tensor{ + b: t.b, + t: C.ggml_cont_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])), + } + case 2: + return &Tensor{ + b: t.b, + t: C.ggml_cont_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])), + } + case 3: + return &Tensor{ + b: t.b, + t: C.ggml_cont_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])), + } + case 4: + return &Tensor{ + b: t.b, + t: C.ggml_cont_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])), + } + default: + panic("unsupported number of dimensions") } } @@ -1176,11 +1209,18 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor { // Default options - opts := &rope.Options{OriginalContextLength: 131072, Factors: &Tensor{}} + opts := rope.Options{ + Factors: &Tensor{}, + OriginalContextLength: 131072, + ExtrapolationFactor: 0., + AttentionFactor: 1., + BetaFast: 32., + BetaSlow: 1., + } // Apply any provided options for _, option := range options { - option(opts) + option(&opts) } dequant := t.t @@ -1200,10 +1240,10 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase C.int(opts.OriginalContextLength), C.float(ropeBase), C.float(ropeScale), - C.float(0.0), - C.float(1.0), - C.float(32.0), - C.float(1.0), + C.float(opts.ExtrapolationFactor), + C.float(opts.AttentionFactor), + C.float(opts.BetaFast), + C.float(opts.BetaSlow), ), } } @@ -1222,6 +1262,13 @@ func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { } } +func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t), + } +} + func (t *Tensor) SILU(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, @@ -1350,3 +1397,65 @@ func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor { t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)), } } + +func (c Context) FromBytes(dtype ml.DType, s []uint8, shape ...int) ml.Tensor { + // Unchecked to handle quantized types + t := c.newTensor(dtype, shape) + if len(s) > 0 { + C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) + } + + return t +} + +// TODO - DRY this out with New if possible +func newTestBackend(size int) *Backend { + var cpus []*C.struct_ggml_backend_device + for _, d := range devices() { + switch C.ggml_backend_dev_type(d) { + case C.GGML_BACKEND_DEVICE_TYPE_CPU: + if len(cpus) == 0 { + // only the first cpu device should be used + cpus = append(cpus, d) + break + } + } + } + var schedBackends []*C.struct_ggml_backend + var schedBufts []*C.struct_ggml_backend_buffer_type + b := C.ggml_backend_dev_init(cpus[0], nil) + bt := 
C.ggml_backend_get_default_buffer_type(b) + C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(runtime.NumCPU()))) + // C.ggml_backend_cpu_set_n_threads(b, 1) // DEBUGGING + schedBackends = append(schedBackends, b) + schedBufts = append(schedBufts, bt) + return &Backend{ + meta: nil, + sched: C.ggml_backend_sched_new( + (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])), + (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), + C.int(len(schedBackends)), + C.size_t(max(8192, size)), + false, + false, + ), + input: bt, + maxGraphNodes: max(8192, size), + schedBackends: schedBackends, + schedBufts: schedBufts, + } +} + +func newTestContext(b *Backend, n int) *Context { + n = max(8192, n) + // slog.Info("XXX before ggml_init") + // defer slog.Info("XXX after ggml_init") + return &Context{ + b: b, + maxGraphNodes: n, + ctx: C.ggml_init(C.struct_ggml_init_params{ + mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false), + no_alloc: true, + }), + } +} diff --git a/ml/backend/ggml/ggml/include/ggml.h b/ml/backend/ggml/ggml/include/ggml.h index e91dedf14..873baa24f 100644 --- a/ml/backend/ggml/ggml/include/ggml.h +++ b/ml/backend/ggml/ggml/include/ggml.h @@ -353,7 +353,7 @@ extern "C" { GGML_TYPE_F16 = 1, GGML_TYPE_Q4_0 = 2, GGML_TYPE_Q4_1 = 3, - // GGML_TYPE_Q4_2 = 4, support has been removed + GGML_TYPE_MXFP4 = 4, // Formerly removed type GGML_TYPE_Q4_2 // GGML_TYPE_Q4_3 = 5, support has been removed GGML_TYPE_Q5_0 = 6, GGML_TYPE_Q5_1 = 7, diff --git a/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp b/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp index ec158dfac..22926d758 100644 --- a/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp @@ -505,6 +505,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = { }; ggml_backend_reg_t ggml_backend_blas_reg(void) { + // MacOS prior to v14 does not include cblas_sgemm - disable this backend if it isn't available + if (&cblas_sgemm == NULL) { + GGML_LOG_INFO("Disabling ggml-blas backend on old MacOS version\n"); + return NULL; + } static struct ggml_backend_reg ggml_backend_blas_reg = { /* .api_version = */ GGML_BACKEND_API_VERSION, /* .iface = */ ggml_backend_blas_reg_i, diff --git a/ml/backend/ggml/ggml/src/ggml-common.h b/ml/backend/ggml/ggml/src/ggml-common.h index 086c822d7..e0d71451b 100644 --- a/ml/backend/ggml/ggml/src/ggml-common.h +++ b/ml/backend/ggml/ggml/src/ggml-common.h @@ -417,6 +417,13 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +#define MXFP4 32 +typedef struct { + uint8_t d; // scale E8M0 float + uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float +} block_mxfp4; +static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding"); + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.h b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.h index e33d9d473..6a25d0626 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.h @@ -58,6 +58,8 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float 
* GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc); + #ifdef __cplusplus } #endif diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c index 2462d2b85..bff9c426e 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c @@ -362,6 +362,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_MXFP4] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_mxfp4, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, }; const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp index 654e2f280..be0aa683b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp @@ -4965,6 +4965,7 @@ void ggml_compute_forward_clamp( case GGML_TYPE_I32: case GGML_TYPE_I64: case GGML_TYPE_F64: + case GGML_TYPE_MXFP4: case GGML_TYPE_COUNT: { GGML_ABORT("fatal error"); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp index 02d406182..ec3ec9b17 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp @@ -250,3 +250,93 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl } return sum = (ggml_float)logf(sum); } + +#define MXFP4 32 +typedef struct { + uint8_t d; // scale E8M0 float + uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float +} block_mxfp4; +static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding"); +#define MXFP4_VALS {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0} + +void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) { + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + ggml_float mxfp4_table[] = MXFP4_VALS; + +#if defined(GGML_SIMD) + float sumf = 0.0f; + const int np = (n & ~(GGML_F32_STEP - 1)); + const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx; + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC scalev; + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + for (int i = 0; i < np; i += GGML_F32_STEP) { // ARM: +16 AVX512: +64 + for (int j = 0; j < GGML_F32_ARR; j++) { // ARM: 0 .. 4 AVX512: 0 .. 
4 + // convert GGML_F32_ARR X elements + const int ib = (i + j*GGML_F32_EPR) / MXFP4; + const block_mxfp4 * GGML_RESTRICT x = &xx[ib]; + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x->d) << 23); + scalev = GGML_F32_VEC_SET1(scale.as_value); + float xf[GGML_F32_EPR]= {0.f}; + assert(((i+j*GGML_F32_EPR) % MXFP4)+GGML_F32_ARR < MXFP4 && "block overrun"); + for (int qi = 0; qi < GGML_F32_EPR/2 ; ++qi) { + xf[qi*2] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf)]; + xf[qi*2+1] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf0) >> 4]; + } + + ax[j] = GGML_F32_VEC_MUL(GGML_F32_VEC_LOAD(xf), scalev); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; i+=2) { + const int ib = i / MXFP4; + const block_mxfp4 * GGML_RESTRICT x = &xx[ib]; + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x->d) << 23); + sumf += y[i] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf)]; + sumf += y[i+1] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf0) >> 4]; + } + + +#else // defined(GGML_SIMD) + const int nb = n / MXFP4; + assert(n % MXFP4 == 0); + + int yi = 0; + + const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx; + + ggml_float sumf = 0.0; + for (int ib = 0; ib < nb; ++ib) { + const block_mxfp4 * GGML_RESTRICT x = &xx[ib + 0]; + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x->d) << 23); + for (int i = 0; i < MXFP4/2; ++i) { + sumf += mxfp4_table[(x->qs[i] & 0xf)] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2]); + sumf += mxfp4_table[(x->qs[i] & 0xf0) >> 4] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2+1]); + } + } +#endif + + *s = sumf; +} diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h index 23cbb3051..7480ca089 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h @@ -42,6 +42,8 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc); void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); +void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc); + void ggml_vec_silu_f32(const int n, float * y, const float * x); ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu index c6dec4276..0e016ccc0 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu @@ -571,6 +571,82 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq4_xs<<>>(vx, y); } +// MXFP4 dequantize derived from dequantize_block_q4_0 +template +static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { + const uint16_t dst_bias = 15; + const uint16_t dst_0p5 
= 0x3800; + const uint16_t dst_m_bits = 10; + const int64_t i = blockIdx.x; + + // assume 32 threads + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + const uint64_t offset = 256*i + MXFP4*ir + 8*il; + dst_t * y = yy + offset; + + const block_mxfp4 * x = (const block_mxfp4 *)vx + ib; + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x->d) << 23); + + // offset within the block 1/4 chunks (8 items) + const uint8_t * q = x->qs + 4*il; + + for (int l = 0; l < 4; ++l) { + uint16_t em0 = q[l] & 0x07; + uint16_t em1 = q[l] & 0x70; + // float16 values + iq1m_scale_t x0; + iq1m_scale_t x1; + + x0.u16 = (em0 << (dst_m_bits - 1)) | ((q[l] & 0x08) << 12); + x1.u16 = (em1 << (dst_m_bits - 5)) | ((q[l] & 0x80) << 8); + + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0.u16 = dst_0p5 | (x0.u16 & 0x8000); + } + if (em1 == 0x10) { + x1.u16 = dst_0p5 | (x1.u16 & 0x8000); + } + // x is zero, do nothing + + // XXX it looks correct here - but mulmat still gives bad results... + // printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n", + // i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 0, scale * float(x0.f16)); + // printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n", + // i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 1, scale * float(x1.f16)); + + y[l*2] = scale.as_value * float(x0.f16); + y[l*2+1] = scale.as_value * float(x1.f16); + } +} + +// derived from dequantize_row_q4_0_cuda +template +static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb32 = k / 32; + const int nb = (k + 255) / 256; + dequantize_block_mxfp4<<>>(vx, y, nb32); +} + template static __global__ void convert_unary( const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, @@ -664,6 +740,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return convert_unary_cont_cuda; case GGML_TYPE_BF16: return convert_unary_cont_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; default: return nullptr; } @@ -713,6 +791,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return convert_unary_cont_cuda; case GGML_TYPE_BF16: return convert_unary_cont_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; default: return nullptr; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index 28ccf4bef..080e7467b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -21,6 +21,7 @@ #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" #include "ggml-cuda/mmv.cuh" +#include "ggml-cuda/mmvmxfp4.cuh" #include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" @@ -1202,7 +1203,7 @@ static void ggml_cuda_op_mul_mat_cublas( const int cc = ggml_cuda_info().devices[id].cc; - const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + const bool use_fp16 = (src0->type == 
GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT && src0->type != GGML_TYPE_MXFP4; if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); @@ -1924,7 +1925,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE + && src0->type != GGML_TYPE_MXFP4; + bool use_mul_mat_vec_mxfp4 = src0->type == GGML_TYPE_MXFP4 + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; @@ -1978,6 +1983,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); + } else if (use_mul_mat_vec_mxfp4) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_mxfp4, nullptr); } else { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } @@ -1997,6 +2004,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ne2 == 1 && src0->type == GGML_TYPE_MXFP4) { + ggml_cuda_mul_mat_vec_mxfp4(ctx, src0, src1, ids, dst); + return; + } if (ne2 == 1) { if (ggml_is_quantized(src0->type)) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); @@ -2498,20 +2509,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud #endif } - // workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n - // number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here - if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256 - && node->ne[2] == 1 - && node->ne[3] == 1 - && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false - && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) { - // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. 
- use_cuda_graph = false; -#ifndef NDEBUG - GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); -#endif - } - if (node->op == GGML_OP_CPY) { // Store the pointers which are updated for each token, such that these can be sent @@ -3056,6 +3053,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_BF16: + case GGML_TYPE_MXFP4: #ifdef GGML_USE_MUSA if (a->type == GGML_TYPE_Q3_K) { return false; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cu new file mode 100644 index 000000000..da62062b3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cu @@ -0,0 +1,307 @@ +#include "ggml.h" +#include "common.cuh" +#include "mmvmxfp4.cuh" + +// MXFP4 implementation derived from mmv.cu float32 code paths +typedef union { + half f16; + uint16_t u16; +} f16_t; + +template // TODO type_acc unused - consider bf16 support +static __global__ void mul_mat_vec_mxfp4( + const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) { + const int64_t row = blockIdx.x; + const int64_t channel_dst = blockIdx.y; + const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int64_t channel_y = ids ? 
channel_dst % nchannels_y : channel_dst; + const int64_t sample_dst = blockIdx.z; + const int64_t sample_x = sample_dst / sample_ratio; + const int64_t sample_y = sample_dst; + const int tid = threadIdx.x; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + const uint16_t dst_bias = 15; + const uint16_t dst_0p5 = 0x3800; + const uint16_t dst_m_bits = 10; + + x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row; + y += sample_y *stride_sample_y + channel_y *stride_channel_y; + dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst; + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float) + float * buf_iw = (float *) data_mmv; + + if (block_size > warp_size) { + if (tid < warp_size) { + buf_iw[tid] = 0.0f; + } + __syncthreads(); + } + + float sumf = 0.0f; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + int offset0 = col2 / (MXFP4/2); + int i = col2 % (MXFP4/2); + const block_mxfp4 *x2 = x+offset0; + + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x2->d) << 23); + uint16_t em0 = x2->qs[i] & 0x07; + uint16_t em1 = x2->qs[i] & 0x70; + // float16 values + f16_t x0; + f16_t x1; + x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12); + x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8); + + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0.u16 = dst_0p5 | (x0.u16 & 0x8000); + } + if (em1 == 0x10) { + x1.u16 = dst_0p5 | (x1.u16 & 0x8000); + } + // x is zero, do nothing + + if (isnan(scale.as_value)) { + sumf = scale.as_value; + break; + } + + const float2 tmpx = {x0.f16, x1.f16}; + const float2 tmpy = y2[col2]; + sumf += tmpx.x*tmpy.x*scale.as_value; + sumf += tmpx.y*tmpy.y*scale.as_value; + } + + sumf = warp_reduce_sum(sumf); + + if (block_size > warp_size) { + buf_iw[tid/warp_size] = sumf; + __syncthreads(); + if (tid >= warp_size) { + return; + } + sumf = buf_iw[tid]; + sumf = warp_reduce_sum(sumf); + } + + if (tid != 0) { + return; + } + + dst[row] = sumf; +} + +template +static void launch_mul_mat_vec_cuda_mxfp4( + const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + GGML_ASSERT(ncols % 2 == 0); + // GGML_ASSERT(stride_row % 2 == 0); // TODO + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; + int device; + int warp_size; + + CUDA_CHECK(cudaGetDevice(&device)); + warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t block_size_best = warp_size; + int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); + int64_t max_block_size = 256; + 
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) { + max_block_size = 128; + } + for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) { + const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); + if (niter < niter_best) { + niter_best = niter; + block_size_best = block_size; + } + } + + const int smem = warp_size*sizeof(float); + const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); + const dim3 block_dims(block_size_best, 1, 1); + + switch (block_size_best) { + case 32: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 64: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 96: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 128: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 160: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 192: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 224: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 256: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +static void mul_mat_vec_cuda_mxfp4( + const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + enum ggml_prec prec, cudaStream_t stream) { + launch_mul_mat_vec_cuda_mxfp4 + (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); +} + +void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, 
const ggml_tensor * ids, ggml_tensor * dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(ne13 == ne3); + + // GGML_ASSERT( nb00 == ts_src0); // TODO adjust for block sizing logic + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + GGML_ASSERT( nb0 == ts_dst); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; + + const int64_t stride_row = src0->nb[1] / ts_src0; + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t stride_channel_x = src0->nb[2] / ts_src0; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t stride_sample_x = src0->nb[3] / ts_src0; + const int64_t stride_sample_y = src1->nb[3] / ts_src1; + const int64_t stride_sample_dst = dst->nb[3] / ts_dst; + const int64_t nsamples_dst = ne3; + const int64_t nsamples_x = ne03; + const int64_t nchannels_x = ne02; + const int64_t nrows = ne01; + const int64_t ncols = ne00; + + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? ne2 : ne1; + const int64_t nchannels_y = ids ? ne11 : ne12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_channel_dst = ids ? s1 : s2; + const int64_t stride_channel_y = ids ? s11 : s12; + + GGML_ASSERT(ncols_dst == 1); + + const block_mxfp4 * src0_d = (const block_mxfp4 *) src0->data; + mul_mat_vec_cuda_mxfp4(src0_d, src1_d, ids_d, dst_d, ncols, nrows, stride_row, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, ctx.stream()); +} + +void ggml_cuda_op_mul_mat_vec_mxfp4( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + GGML_ASSERT(src1_ncols == 1); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? 
ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + // ggml_cuda_op provides single, contiguous matrices + const int64_t stride_row = ne00 / MXFP4; + const int64_t nchannels_x = 1; + const int64_t nchannels_y = 1; + const int64_t nchannels_dst = 1; + const int64_t stride_channel_x = 0; + const int64_t stride_channel_y = 0; + const int64_t stride_channel_dst = 0; + const int64_t nsamples_x = 1; + const int64_t nsamples_dst = 1; + const int64_t stride_sample_x = 0; + const int64_t stride_sample_y = 0; + const int64_t stride_sample_dst = 0; + + const block_mxfp4 * src0_d = (const block_mxfp4 *) src0_dd_i; + mul_mat_vec_cuda_mxfp4(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + + GGML_UNUSED(ctx); + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_ncols); + GGML_UNUSED(src1_padded_row_size); +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cuh new file mode 100644 index 000000000..a08fc7800 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cuh @@ -0,0 +1,9 @@ +#include "common.cuh" + +void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + +void ggml_cuda_op_mul_mat_vec_mxfp4( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 8f9a25e6f..5eba1dafc 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -421,6 +421,13 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +#define MXFP4 32 +typedef struct { + uint8_t d; // scale E8M0 float + uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float +} block_mxfp4; +static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding"); + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL @@ -1929,6 +1936,9 @@ GGML_TABLE_END() #define N_R0_IQ4_XS 2 #define N_SG_IQ4_XS 2 +#define N_R0_MXFP4 4 +#define N_SG_MXFP4 2 + // kernel argument structs // // - element counters (e.g. 
ne00) typically use int32_t to reduce register usage @@ -4380,16 +4390,16 @@ void mul_vec_q_n_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - ushort tiisg, - ushort sgitg) { - const int nb = args.ne00/QK4_0; + uint3 tgpig, // Threadgroup Position in Grid + ushort tiisg, // Thread Index in SIMD Group + ushort sgitg) { // SIMD Group Index in ThreadGroup + const int nb = args.ne00/QK4_0; // src0->ne[0] / 32 const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * nsg + sgitg) * nr0; // nsg=2 nr0=4 const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -9222,6 +9232,49 @@ kernel void kernel_mul_mm_id( } } +template +void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) { + float4x4 reg_f; + const ushort dst_bias = 15; + const ushort dst_0p5 = 0x3800; + const ushort dst_m_bits = 10; + const half scale = (half)(as_type(((uint32_t)xb->d) << 23)); + // il:0 first 16, il:1 last 16 + for (int i = 0; i < 8; i++) { + ushort em0 = xb->qs[il*8 + i] & 0x07; + ushort em1 = xb->qs[il*8 + i] & 0x70; + // float16 values + ushort x0 = (em0 << (dst_m_bits - 1)) | ((xb->qs[il*8 + i] & 0x08) << 12); + ushort x1 = (em1 << (dst_m_bits - 5)) | ((xb->qs[il*8 + i] & 0x80) << 8); + + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0 = x0 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1 = x1 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0 = dst_0p5 | (x0 & 0x8000); + } + if (em1 == 0x10) { + x1 = dst_0p5 | (x1 & 0x8000); + } + // x is zero, do nothing + + if (isnan(scale)) { + reg_f[i/2][2*(i%2) + 0] = scale; + reg_f[i/2][2*(i%2) + 1] = scale; + } else { + reg_f[i/2][2*(i%2) + 0] = scale * as_type(x0); + reg_f[i/2][2*(i%2) + 1] = scale * as_type(x1); + } + } + reg = (type4x4) reg_f; +} + #define QK_NL 16 // @@ -9289,6 +9342,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; + // // indirect matrix-matrix multiplication // @@ -9320,6 +9375,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; + // // matrix-vector multiplication @@ -9436,6 +9493,120 @@ kernel void kernel_mul_mv_id( sgitg); } +// MXFP32 implementation derived from mul_vec_q_n_f32_impl and block_q_n_dot_y +void mul_mv_mxfp4_f32_impl( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const ushort dst_bias = 15; + const ushort dst_0p5 = 0x3800; + const ushort dst_m_bits = 10; + const int nr0 = N_R0_MXFP4; + const int nsg = N_SG_MXFP4; + const int nw = N_SIMDWIDTH; + const int nb = args.ne00/MXFP4; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + 
sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_mxfp4 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax[row] = (device const block_mxfp4 *) ((device char *) src0 + offset0); + } + + float yl[16]; // src1 vector cache + float sumf[nr0] = {0.f}; + + const short ix = (tiisg/2); + const short il = (tiisg%2)*16; + + device const float * yb = y + ix*MXFP4 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { + +#pragma unroll + for (short row = 0; row < nr0; row++) { + // Processes 16 items + device const block_mxfp4 * qb_curr = ax[row] + ib; + float d = as_type(((uint32_t)(ax[row] + ib)->d) << 23); + // il = 0 or 16 + device const uint8_t *qs = ((device const uint8_t *) qb_curr + 1 + il/2); + for (int i = 0; i < 8; ++i) { + ushort em0 = qs[i] & 0x07; + ushort em1 = qs[i] & 0x70; + ushort x0 = (em0 << (dst_m_bits - 1)) | ((qs[i] & 0x08) << 12); + ushort x1 = (em1 << (dst_m_bits - 5)) | ((qs[i] & 0x80) << 8); + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0 = x0 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1 = x1 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0 = dst_0p5 | (x0 & 0x8000); + } + if (em1 == 0x10) { + x1 = dst_0p5 | (x1 & 0x8000); + } + // x is zero, do nothing + if (!isnan(d)) { + sumf[row] += yb[i*2] * as_type(x0) * d + + yb[i*2+1] * as_type(x1) * d; + } else { + sumf[row] = d; + } + } + } + + yb += MXFP4 * 16; + } + + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_mxfp4_f32")]] +kernel void kernel_mul_mv_mxfp4_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; @@ -9465,6 +9636,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; + kernel void kernel_pool_2d_max_f32( device const float * src0, device float * dst, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h index 17eab976f..938386ba8 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h +++ 
b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h @@ -65,6 +65,9 @@ #define N_R0_IQ4_XS 2 #define N_SG_IQ4_XS 2 +#define N_R0_MXFP4 4 +#define N_SG_MXFP4 2 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m index ab46f6e3a..d8e05a21b 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m @@ -40,6 +40,7 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; static struct ggml_backend_reg g_ggml_backend_metal_reg; static struct ggml_backend_device g_ggml_backend_metal_device; + // information about a Metal device // note: assumes single GPU device - the default one // TODO: support multiple GPU devices @@ -209,6 +210,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, @@ -288,6 +290,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, @@ -310,6 +313,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, @@ -334,6 +338,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, @@ -934,7 +939,7 @@ static id ggml_metal_load_library(id device, bool use_bfl MTLCompileOptions * options = [MTLCompileOptions new]; options.preprocessorMacros = prep; - + //[options setFastMathEnabled:false]; metal_library = [device newLibraryWithSource:src options:options error:&error]; @@ -1157,6 +1162,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); @@ -1236,6 +1242,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); @@ -1258,6 +1265,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); @@ -1282,6 +1290,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); @@ -3007,6 +3016,7 @@ static bool ggml_metal_encode_node( case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break; default: GGML_ABORT("MUL MAT-MAT not implemented"); } @@ -3212,6 +3222,12 @@ static bool ggml_metal_encode_node( smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; } break; + case GGML_TYPE_MXFP4: + { + nsg = N_SG_MXFP4; + nr0 = N_R0_MXFP4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline; + } break; default: { GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); @@ -3396,6 +3412,7 @@ static bool ggml_metal_encode_node( case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 
].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break; default: GGML_ABORT("MUL_MAT_ID not implemented"); } @@ -3607,6 +3624,12 @@ static bool ggml_metal_encode_node( smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; } break; + case GGML_TYPE_MXFP4: + { + nsg = N_SG_MXFP4; + nr0 = N_R0_MXFP4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline; + } break; default: { GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index 08e8d8070..69fa17de3 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -1902,16 +1902,16 @@ void mul_vec_q_n_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - ushort tiisg, - ushort sgitg) { - const int nb = args.ne00/QK4_0; + uint3 tgpig, // Threadgroup Position in Grid + ushort tiisg, // Thread Index in SIMD Group + ushort sgitg) { // SIMD Group Index in ThreadGroup + const int nb = args.ne00/QK4_0; // src0->ne[0] / 32 const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * nsg + sgitg) * nr0; // nsg=2 nr0=4 const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6744,6 +6744,49 @@ kernel void kernel_mul_mm_id( } } +template +void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) { + float4x4 reg_f; + const ushort dst_bias = 15; + const ushort dst_0p5 = 0x3800; + const ushort dst_m_bits = 10; + const half scale = (half)(as_type(((uint32_t)xb->d) << 23)); + // il:0 first 16, il:1 last 16 + for (int i = 0; i < 8; i++) { + ushort em0 = xb->qs[il*8 + i] & 0x07; + ushort em1 = xb->qs[il*8 + i] & 0x70; + // float16 values + ushort x0 = (em0 << (dst_m_bits - 1)) | ((xb->qs[il*8 + i] & 0x08) << 12); + ushort x1 = (em1 << (dst_m_bits - 5)) | ((xb->qs[il*8 + i] & 0x80) << 8); + + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0 = x0 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1 = x1 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0 = dst_0p5 | (x0 & 0x8000); + } + if (em1 == 0x10) { + x1 = dst_0p5 | (x1 & 0x8000); + } + // x is zero, do nothing + + if (isnan(scale)) { + reg_f[i/2][2*(i%2) + 0] = scale; + reg_f[i/2][2*(i%2) + 1] = scale; + } else { + reg_f[i/2][2*(i%2) + 0] = scale * as_type(x0); + reg_f[i/2][2*(i%2) + 1] = scale * as_type(x1); + } + } + reg = (type4x4) reg_f; +} + #define QK_NL 16 // @@ -6811,6 +6854,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; + // // indirect matrix-matrix multiplication // @@ -6842,6 +6887,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; 
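+// MXFP4 decode note (illustrative, not exhaustive): each block stores one E8M0
+// scale byte d and 32 packed E2M1 nibbles. Example: d = 0x7F gives scale bits
+// 0x7F << 23 = 1.0f, and the nibble 0x5 (sign 0, exponent 10, mantissa 1) is 3.0,
+// so the decoded element is 3.0f. dequantize_mxfp4 and mul_mv_mxfp4_f32_impl in
+// this file build the same result by assembling an IEEE half (bias 15,
+// 10 mantissa bits) from the E2M1 bits and multiplying by the decoded scale.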
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; + // // matrix-vector multiplication @@ -6958,6 +7005,120 @@ kernel void kernel_mul_mv_id( sgitg); } +// MXFP32 implementation derived from mul_vec_q_n_f32_impl and block_q_n_dot_y +void mul_mv_mxfp4_f32_impl( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const ushort dst_bias = 15; + const ushort dst_0p5 = 0x3800; + const ushort dst_m_bits = 10; + const int nr0 = N_R0_MXFP4; + const int nsg = N_SG_MXFP4; + const int nw = N_SIMDWIDTH; + const int nb = args.ne00/MXFP4; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_mxfp4 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax[row] = (device const block_mxfp4 *) ((device char *) src0 + offset0); + } + + float yl[16]; // src1 vector cache + float sumf[nr0] = {0.f}; + + const short ix = (tiisg/2); + const short il = (tiisg%2)*16; + + device const float * yb = y + ix*MXFP4 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { + +#pragma unroll + for (short row = 0; row < nr0; row++) { + // Processes 16 items + device const block_mxfp4 * qb_curr = ax[row] + ib; + float d = as_type(((uint32_t)(ax[row] + ib)->d) << 23); + // il = 0 or 16 + device const uint8_t *qs = ((device const uint8_t *) qb_curr + 1 + il/2); + for (int i = 0; i < 8; ++i) { + ushort em0 = qs[i] & 0x07; + ushort em1 = qs[i] & 0x70; + ushort x0 = (em0 << (dst_m_bits - 1)) | ((qs[i] & 0x08) << 12); + ushort x1 = (em1 << (dst_m_bits - 5)) | ((qs[i] & 0x80) << 8); + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0 = x0 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1 = x1 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0 = dst_0p5 | (x0 & 0x8000); + } + if (em1 == 0x10) { + x1 = dst_0p5 | (x1 & 0x8000); + } + // x is zero, do nothing + if (!isnan(d)) { + sumf[row] += yb[i*2] * as_type(x0) * d + + yb[i*2+1] * as_type(x1) * d; + } else { + sumf[row] = d; + } + } + } + + yb += MXFP4 * 16; + } + + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_mxfp4_f32")]] +kernel void kernel_mul_mv_mxfp4_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, 
tiisg, sgitg); +} + typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; @@ -6987,6 +7148,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; + kernel void kernel_pool_2d_max_f32( device const float * src0, device float * dst, diff --git a/ml/backend/ggml/ggml/src/ggml-quants.c b/ml/backend/ggml/ggml/src/ggml-quants.c index 84ec6dfe3..17c308aae 100644 --- a/ml/backend/ggml/ggml/src/ggml-quants.c +++ b/ml/backend/ggml/ggml/src/ggml-quants.c @@ -4925,6 +4925,144 @@ void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RE quantize_iq2_s(x, y, 1, k, NULL); } +// =============================== mxfp4 (de)-quantization + +void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { + static const int qk = MXFP4; + static const uint32_t E8_BIAS = 127; + static const uint32_t E2_BIAS = 1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + } + } + + const float dequant_scale = amax / 6.0f; + uint32_t dequant_scale_exponent = 0; + memcpy(&dequant_scale_exponent, &dequant_scale, sizeof(dequant_scale_exponent)); + + // Rounding up + dequant_scale_exponent = (dequant_scale_exponent + 0x007FFFFF) & 0x7F800000; + // Rounding down + // dequant_scale_exponent = dequant_scale_exponent & 0x7F800000; + + float dequant_scale_rounded = 0.0f; + memcpy(&dequant_scale_rounded, &dequant_scale_exponent, sizeof(dequant_scale_rounded)); + float quant_scale = 0.0f; + if (dequant_scale_rounded != 0.0f) { + quant_scale = 1.0f / dequant_scale_rounded; + } + + y[i].d = (uint8_t)(dequant_scale_exponent >> 23); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + j*2]*quant_scale; + const float x1 = x[i*qk + j*2+1]*quant_scale; + + uint32_t xi0 = 0; + uint32_t xi1 = 0; + memcpy(&xi0, &x0, sizeof(xi0)); + memcpy(&xi1, &x1, sizeof(xi1)); + + uint32_t s0 = xi0 & 0x80000000; + uint32_t s1 = xi1 & 0x80000000; + uint32_t e0 = (xi0 >> 23) & 0xFF; + uint32_t e1 = (xi1 >> 23) & 0xFF; + uint32_t m0 = (xi0 & 0x7FFFFF); + uint32_t m1 = (xi1 & 0x7FFFFF); + + // 0.25 <= x < 0.75 maps to 0.5, a denormal number + // Move implicit bit 1 at the beginning to mantissa for denormals + // adjusted_exponents + uint32_t ae0 = E8_BIAS - (e0 + 1); + uint32_t ae1 = E8_BIAS - (e1 + 1); + if (e0 < E8_BIAS) { + m0 = (0x400000 | (m0 >> 1)) >> ae0; + } + if (e1 < E8_BIAS) { + m1 = (0x400000 | (m1 >> 1)) >> ae1; + } + + // For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0. 
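+ // Worked example (hedged sketch): a scaled input of 3.0f has IEEE bits with
+ // e0 = 128 and m0 = 0x400000; the re-bias below gives e0 = 2, so
+ // ((e0 << 2) | (m0 >> 21)) = 0b1010, and the round-and-shift yields the
+ // E2M1 code 0x5, which dequantizes back to exactly 3.0.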
+ e0 = MAX(e0, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS); + e1 = MAX(e1, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS); + + // Combine sign, exponent, and mantissa, while saturating + // rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right + uint32_t tmp0 = MIN((((e0 << 2) | (m0 >> 21)) + 1) >> 1, 0x7); + uint32_t tmp1 = MIN((((e1 << 2) | (m1 >> 21)) + 1) >> 1, 0x7); + uint8_t v0 = (uint8_t)((s0 >> 28) | tmp0); + uint8_t v1 = (uint8_t)((s1 >> 28) | tmp1); + y[i].qs[j] = v0; + y[i].qs[j] |= v1 << 4; + } + } +} + +void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % MXFP4 == 0); + + const int nb = k / MXFP4; + const uint16_t dst_bias = 15; + const uint16_t dst_0p5 = 0x3800; + const uint16_t dst_m_bits = 10; + + for (int i = 0; i < nb; i++) { + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x[i].d) << 23); + for (int j = 0; j < MXFP4/2; ++j) { + uint16_t em0 = x[i].qs[j] & 0x07; + uint16_t em1 = x[i].qs[j] & 0x70; + // float16 values + uint16_t x0 = (em0 << (dst_m_bits - 1)) | ((x[i].qs[j] & 0x08) << 12); + uint16_t x1 = (em1 << (dst_m_bits - 5)) | ((x[i].qs[j] & 0x80) << 8); + + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0 = x0 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1 = x1 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0 = dst_0p5 | (x0 & 0x8000); + } + if (em1 == 0x10) { + x1 = dst_0p5 | (x1 & 0x8000); + } + // x is zero, do nothing + + if (isnan(scale.as_value)) { + y[i*MXFP4 + j*2] = scale.as_value; + y[i*MXFP4 + j*2+1] = scale.as_value; + } else { + y[i*MXFP4 + j*2] = GGML_FP16_TO_FP32(x0)*scale.as_value; + y[i*MXFP4 + j*2+1] = GGML_FP16_TO_FP32(x1)*scale.as_value; + } + } + } +} + + +size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row); +} + // =============================== data validation static bool validate_float(float f, size_t i) { @@ -5214,7 +5352,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; - + case GGML_TYPE_MXFP4: + // TODO - anything to validate? 
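+ // Hedged note: the packed E2M1 nibbles always decode to finite values, so the
+ // only candidate check is the E8M0 scale byte; d == 0xFF yields scale bits
+ // 0xFF << 23, i.e. +Inf, and could be rejected here in the future.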
+ break; case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: diff --git a/ml/backend/ggml/ggml/src/ggml-quants.h b/ml/backend/ggml/ggml/src/ggml-quants.h index d09173e11..2fc40f754 100644 --- a/ml/backend/ggml/ggml/src/ggml-quants.h +++ b/ml/backend/ggml/ggml/src/ggml-quants.h @@ -37,6 +37,8 @@ GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_ GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); + // Dequantization GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -65,6 +67,8 @@ GGML_API void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, floa GGML_API void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") GGML_API size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); @@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + GGML_API void iq2xs_init_impl(enum ggml_type type); GGML_API void iq2xs_free_impl(enum ggml_type type); GGML_API void iq3xs_init_impl(int grid_size); diff --git a/ml/backend/ggml/ggml/src/ggml.c b/ml/backend/ggml/ggml/src/ggml.c index 8a6546240..0f3c98340 100644 --- a/ml/backend/ggml/ggml/src/ggml.c +++ b/ml/backend/ggml/ggml/src/ggml.c @@ -589,11 +589,13 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_q4_1, .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, }, - [4] = { // GGML_TYPE_Q4_2 - .type_name = "DEPRECATED", - .blck_size = 0, - .type_size = 0, - .is_quantized = false, + [GGML_TYPE_MXFP4] = { // formerly deprecated GGML_TYPE_Q4_2 + .type_name = "mxfp4", + .blck_size = MXFP4, + .type_size = sizeof(block_mxfp4), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_mxfp4, + .from_float_ref = (ggml_from_float_t) quantize_row_mxfp4_ref, }, [5] = { // GGML_TYPE_Q4_3 .type_name = "DEPRECATED", @@ -6446,6 +6448,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl 
(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); diff --git a/ml/backend/ggml/ggml_test.go b/ml/backend/ggml/ggml_test.go new file mode 100644 index 000000000..70ebb9df4 --- /dev/null +++ b/ml/backend/ggml/ggml_test.go @@ -0,0 +1,60 @@ +package ggml + +import ( + "bytes" + "log/slog" + "os" + "slices" + "testing" + + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/ml" +) + +func TestMain(m *testing.M) { + slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel())) + os.Exit(m.Run()) +} + +func setup(tb testing.TB) ml.Backend { + tb.Helper() + + f, err := os.CreateTemp(tb.TempDir(), "*.bin") + if err != nil { + tb.Fatal(err) + } + defer f.Close() + + if err := ggml.WriteGGUF(f, ggml.KV{ + "general.architecture": "test", + "test.block_count": uint32(1), + }, []*ggml.Tensor{ + {Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 4))}, + }); err != nil { + tb.Fatal(err) + } + + b, err := New(f.Name(), ml.BackendParams{NumGPULayers: 1}) + if err != nil { + tb.Fatal(err) + } + + return b +} + +// initContextOrSkip takes a testing.T and true for GPU +// If GPUs are not available, the current test is skipped +// gpu=false will always succed +func initContextOrSkip(t *testing.T, b ml.Backend, gpu bool) ml.Context { + if gpu && len(b.(*Backend).schedBackends) == 1 { + t.Skip("No GPU detected, skipping GPU test case") + } + ctx := b.NewContext() + t.Cleanup(func() { ctx.Close() }) + if gpu { + return ctx.Layer(0) + } + return ctx.Input() +} diff --git a/ml/backend/ggml/mxfp4_test.go b/ml/backend/ggml/mxfp4_test.go new file mode 100644 index 000000000..3c17eb8aa --- /dev/null +++ b/ml/backend/ggml/mxfp4_test.go @@ -0,0 +1,795 @@ +package ggml + +import ( + "math" + "math/rand" + "os" + "testing" + + "github.com/ollama/ollama/ml" + + fsggml "github.com/ollama/ollama/fs/ggml" +) + +/* + To get GPUs loading in these tests on windows... + + $env:OLLAMA_LIBRARY_PATH="$(pwd)\build\lib\ollama" + $env:PATH="$(pwd)\build\lib\ollama;$env:PATH" + + go test .\ml\backend\ggml\... 
-run TestMXFP4 +*/ + +// MXFP4 reference: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + +// E2M1 values +var mxfp4_vals = []float32{ + 0.0, // 0 00 0 = 0x0 + 0.5, // 0 00 1 = 0x1 + 1.0, // 0 01 0 = 0x2 + 1.5, // 0 01 1 = 0x3 + 2.0, // 0 10 0 = 0x4 + 3.0, // 0 10 1 = 0x5 + 4.0, // 0 11 0 = 0x6 + 6.0, // 0 11 1 = 0x7 + 0.0, // 1 00 0 = 0x8 + -0.5, // 1 00 1 = 0x9 + -1.0, // 1 01 0 = 0xa + -1.5, // 1 01 1 = 0xb + -2.0, // 1 10 0 = 0xc + -3.0, // 1 10 1 = 0xd + -4.0, // 1 11 0 = 0xe + -6.0, // 1 11 1 = 0xf +} + +func TestMXFP4Ops(t *testing.T) { + b := setup(t) + for _, useGPU := range []bool{false, true} { + useGPU := useGPU + var label string + if useGPU { + label = "gpu" + } else { + label = "cpu" + } + t.Run(label, func(t *testing.T) { + t.Run("mulmatid", func(t *testing.T) { + // Use exact values that are supported without scaling so we can compare against an fp32 tensor + t.Run("exact", func(t *testing.T) { + r := rand.New(rand.NewSource(0)) + ctx := initContextOrSkip(t, b, useGPU) + const s00 = 64 + const s01 = 1 + const s02 = 2 + const s10 = s00 + const s11 = 1 + const s12 = 1 + // const s00 = 2880 + // const s01 = 5760 + // const s02 = 32 + // const s10 = s00 + // const s11 = 1 + // const s12 = 64 + + data := [s00 * s01 * s02]float32{} + for i := range data { + data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)] + } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + dtype := ml.DTypeMXFP4 + t1 := ctx.(*Context).FromBytes(dtype, mxData, s00, s01, s02) + t1f := ctx.(*Context).FromFloatSlice(data[:], s00, s01, s02) + // for i := range len(data) / 32 { // MXFP4 block size + // vals := [32]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", data[i*32+j]) + // } + // t.Logf(" t1[%s]\n", strings.Join(vals[:], ", ")) + // } + + // random 0-1 float + d2 := [s10 * s11 * s12]float32{} + for i := range d2 { + d2[i] = float32(r.Float32()) + } + // for i := range len(d2) / s10 { + // vals := [s10]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", d2[i*s10+j]) + // } + // t.Logf(" t2[%s]\n", strings.Join(vals[:], ", ")) + // } + t2 := ctx.(*Context).FromFloatSlice(d2[:], s10, s11, s12) + + d3 := [4 * s12]int32{} + for i := range d3 { + d3[i] = int32(i) % s02 + } + t3 := ctx.(*Context).FromIntSlice(d3[:], 4, s12) + + // t.Log("calling MulmatID") + t4 := t1.MulmatID(ctx, t2, t3) + t4f := t1f.MulmatID(ctx, t2, t3) + d4 := ml.Dump(ctx, t4, ml.DumpWithPrecision(2)) // lower precision for CPU accuracy + d4f := ml.Dump(ctx, t4f, ml.DumpWithPrecision(2)) + if d4 != d4f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4) + } + // t.Logf("MulmatID results matched:\n%s", d4) + }) + + t.Run("range", func(t *testing.T) { + r := rand.New(rand.NewSource(0)) + ctx := initContextOrSkip(t, b, useGPU) + const s0 = 64 + const s1 = 2 + const s2 = 4 + const idlen = 4 + data := [s0 * s1 * s2]float32{} + inTotal := float32(0) + for i := range data { + data[i] = float32(i) + inTotal += float32(i) + } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + // Reconvert back to floats to remove the quantization fidelity loss for comparison + dataf := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data))) + dtype := ml.DTypeMXFP4 + t1 := ctx.(*Context).FromBytes(dtype, mxData, s0, s1, s2) + t1f := ctx.(*Context).FromFloatSlice(dataf, s0, s1, s2) + // for i := range len(data) / 32 { + // vals := [32]string{} + // for j := range vals { + // vals[j] = 
fmt.Sprintf("%0.2f", dataf[i*32+j]) + // } + // t.Logf(" t1[%s]\n", strings.Join(vals[:], ", ")) + // } + + d2 := [s0]float32{} + for i := range d2 { + // d2[i] = float32(i) + d2[i] = float32(r.Float32()) + } + // for i := range len(d2) / s0 { + // vals := [s0]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", d2[i*s0+j]) + // } + // t.Logf(" t2[%s]\n", strings.Join(vals[:], ", ")) + // } + t2 := ctx.(*Context).FromFloatSlice(d2[:], s0) + + // TODO - there might be a CUDA bug here... + d3 := [idlen]int32{1, 1, 2, 3} + // for i := range d3 { + // d3[i] = int32(i) % s2 + // t.Logf("%d] %d", i, d3[i]) + // } + t3 := ctx.(*Context).FromIntSlice(d3[:], idlen) + + // t.Log("calling Mulmat") + t4 := t1.MulmatID(ctx, t2, t3) + t4f := t1f.MulmatID(ctx, t2, t3) + // Metal has some drift so use reduced precision for dump comparisons + d4 := ml.Dump(ctx, t4, ml.DumpWithPrecision(2)) + d4f := ml.Dump(ctx, t4f, ml.DumpWithPrecision(2)) + r4 := t4.Floats() + r4f := t4f.Floats() + sim := cosineSimilarity(r4, r4f) + if sim < 0.99 { + t.Logf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4) + t.Fatalf("failed similarity test: %f", sim) + } + t.Logf("similarity: %f", sim) + + if d4 != d4f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4) + } + // t.Logf("mxfp4 result\n%s", d4) + }) + t.Run("random", func(t *testing.T) { + r := rand.New(rand.NewSource(0)) + ctx := initContextOrSkip(t, b, useGPU) + const s00 = 2880 + const s01 = 5760 + const s02 = 32 + const s10 = s00 + const s11 = 1 + const s12 = 64 + const idlen = 4 + + data := [s00 * s01 * s02]float32{} + for i := range data { + data[i] = float32(r.Float32() * 10.0) + } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + // Reconvert back to floats to remove the quantization fidelity loss for comparison + dataf := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data))) + dtype := ml.DTypeMXFP4 + t1 := ctx.(*Context).FromBytes(dtype, mxData, s00, s01, s02) + t1f := ctx.(*Context).FromFloatSlice(dataf, s00, s01, s02) + // for i := range len(data) / 32 { + // vals := [32]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", dataf[i*32+j]) + // } + // t.Logf(" t1[%s]\n", strings.Join(vals[:], ", ")) + // } + + d2 := [s10 * s11 * s12]float32{} + for i := range d2 { + // d2[i] = float32(i) + d2[i] = float32(r.Float32()) + } + // for i := range len(d2) / s0 { + // vals := [s0]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", d2[i*s0+j]) + // } + // t.Logf(" t2[%s]\n", strings.Join(vals[:], ", ")) + // } + t2 := ctx.(*Context).FromFloatSlice(d2[:], s10, s11, s12) + + // arange equiv + d3 := [idlen * s12]int32{} + for i := range d3 { + d3[i] = int32(i) % s02 + } + t3 := ctx.(*Context).FromIntSlice(d3[:], idlen, s12) + + // t.Log("calling Mulmat") + // t3 := t1.Mulmat(ctx, t2) + // t3f := t1f.Mulmat(ctx, t2) + t4 := t1.MulmatID(ctx, t2, t3) + t4f := t1f.MulmatID(ctx, t2, t3) + // Metal and CPU have some drift so use reduced precision for dump comparisons + d4 := ml.Dump(ctx, t4, ml.DumpWithPrecision(1)) + d4f := ml.Dump(ctx, t4f, ml.DumpWithPrecision(1)) + // t.Logf("mxfp4 data: \n%s", d4) + r4 := t4.Floats() + r4f := t4f.Floats() + sim := cosineSimilarity(r4, r4f) + if sim < 0.99 { + t.Logf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4) + t.Fatalf("failed similarity test: %f", sim) + } + t.Logf("similarity: %f", sim) + + if d4 != d4f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4) + } + }) + 
+ // Use data file(s) with real data + t.Run("example_7", func(t *testing.T) { + ctx := initContextOrSkip(t, b, useGPU) + data0, err := os.ReadFile("mlp-gateup.bin") + if err != nil { + t.Skip("missing mlp-gateup.bin file, skipping test") + } + data1, err := os.ReadFile("hidden-states-7.bin") + if err != nil { + t.Skip("missing hidden-states.bin file, skipping test") + } + data2, err := os.ReadFile("selected-experts-7.bin") + if err != nil { + t.Skip("missing selected-experts.bin file, skipping test") + } + + dtype := ml.DTypeMXFP4 + data0f := ConvertToF32(data0, uint32(fsggml.TensorTypeMXFP4), 2880*5760*32) + t1 := ctx.(*Context).FromBytes(dtype, data0, 2880, 5760, 32) + t1f := ctx.(*Context).FromFloatSlice(data0f, 2880, 5760, 32) + + // t.Logf("f32: \n%s", ml.Dump(ctx, t1f)) + + t2 := ctx.(*Context).FromBytes(ml.DTypeF32, data1, 2880, 1, 7) + // t.Logf("hidden-state: \n%s", ml.Dump(ctx, t2)) + + t3 := ctx.(*Context).FromBytes(ml.DTypeI32, data2, 4, 7) + // t.Logf("experts: \n%s", ml.Dump(ctx, t3)) + + // t.Log("calling MulmatID") + t4 := t1.MulmatID(ctx, t2, t3) + t4f := t1f.MulmatID(ctx, t2, t3) + + d4 := ml.Dump(ctx, t4) + d4f := ml.Dump(ctx, t4f) + + r4 := t4.Floats() + r4f := t4f.Floats() + sim := cosineSimilarity(r4, r4f) + if sim < 0.99 { + t.Fatalf("failed similarity test: %f", sim) + } + t.Logf("similarity: %f", sim) + + if d4 != d4f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4) + } + // t.Logf("MulmatID results matched:\n%s", d4) + }) + + // Use data file(s) with real data + t.Run("example_384", func(t *testing.T) { + ctx := initContextOrSkip(t, b, useGPU) + data0, err := os.ReadFile("mlp-gateup.bin") + if err != nil { + t.Skip("missing mlp-gateup.bin file, skipping test") + } + data1, err := os.ReadFile("hidden-states-384.bin") + if err != nil { + t.Skip("missing hidden-states.bin file, skipping test") + } + data2, err := os.ReadFile("selected-experts-384.bin") + if err != nil { + t.Skip("missing selected-experts.bin file, skipping test") + } + + dtype := ml.DTypeMXFP4 + data0f := ConvertToF32(data0, uint32(fsggml.TensorTypeMXFP4), 2880*5760*32) + t1 := ctx.(*Context).FromBytes(dtype, data0, 2880, 5760, 32) + t1f := ctx.(*Context).FromFloatSlice(data0f, 2880, 5760, 32) + + // t.Logf("f32: \n%s", ml.Dump(ctx, t1f)) + + t2 := ctx.(*Context).FromBytes(ml.DTypeF32, data1, 2880, 1, 384) + // t.Logf("hidden-state: \n%s", ml.Dump(ctx, t2)) + + t3 := ctx.(*Context).FromBytes(ml.DTypeI32, data2, 4, 384) + // t.Logf("experts: \n%s", ml.Dump(ctx, t3)) + + // t.Log("calling MulmatID") + t4 := t1.MulmatID(ctx, t2, t3) + t4f := t1f.MulmatID(ctx, t2, t3) + + d4 := ml.Dump(ctx, t4, ml.DumpWithPrecision(3)) + d4f := ml.Dump(ctx, t4f, ml.DumpWithPrecision(3)) + + r4 := t4.Floats() + r4f := t4f.Floats() + sim := cosineSimilarity(r4, r4f) + if sim < 0.99 { + t.Fatalf("failed similarity test: %f", sim) + } + t.Logf("similarity: %f", sim) + + if d4 != d4f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4) + } + // t.Logf("MulmatID results matched:\n%s", d4) + }) + + // Use data file(s) with real data + t.Run("example_1d", func(t *testing.T) { + r := rand.New(rand.NewSource(0)) + ctx := initContextOrSkip(t, b, useGPU) + data0, err := os.ReadFile("mlp-gateup.bin") + if err != nil { + t.Skip("missing mlp-gateup.bin file, skipping test") + } + + dtype := ml.DTypeMXFP4 + data0f := ConvertToF32(data0, uint32(fsggml.TensorTypeMXFP4), 2880*5760*32) + t1 := ctx.(*Context).FromBytes(dtype, data0, 2880, 5760, 32) + t1f := ctx.(*Context).FromFloatSlice(data0f, 
2880, 5760, 32) + + // t.Logf("f32: \n%s", ml.Dump(ctx, t1f)) + data1 := [2880]float32{} + for i := range data1 { + data1[i] = float32(r.Float32()) + } + + t2 := ctx.(*Context).FromFloatSlice(data1[:], 2880) + // t.Logf("hidden-state: \n%s", ml.Dump(ctx, t2)) + data2 := [4]int32{ + 12, 30, 17, 7, + // 7, 17, 12, 30, + } + + t3 := ctx.(*Context).FromIntSlice(data2[:], 4) + // t.Logf("experts: \n%s", ml.Dump(ctx, t3)) + + // t.Log("calling MulmatID") + t4 := t1.MulmatID(ctx, t2, t3) + t4f := t1f.MulmatID(ctx, t2, t3) + + d4 := ml.Dump(ctx, t4) + d4f := ml.Dump(ctx, t4f) + + r4 := t4.Floats() + r4f := t4f.Floats() + sim := cosineSimilarity(r4, r4f) + if sim < 0.99 { + t.Fatalf("failed similarity test: %f", sim) + } + t.Logf("similarity: %f", sim) + + if d4 != d4f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4) + } + // t.Logf("MulmatID results matched:\n%s", d4) + }) + }) + + t.Run("mm", func(t *testing.T) { + t.Run("example", func(t *testing.T) { + r := rand.New(rand.NewSource(0)) + ctx := initContextOrSkip(t, b, useGPU) + data0, err := os.ReadFile("mlp-gateup.bin") + if err != nil { + t.Skip("missing mlp-gateup.bin file, skipping test") + } + data1 := [2880 * 1 * 32]float32{} + for i := range data1 { + data1[i] = float32(r.Float32()) + } + + dtype := ml.DTypeMXFP4 + data0f := ConvertToF32(data0, uint32(fsggml.TensorTypeMXFP4), 2880*5760*32) + t1 := ctx.(*Context).FromBytes(dtype, data0, 2880, 5760, 32) + t1f := ctx.(*Context).FromFloatSlice(data0f, 2880, 5760, 32) + + // t.Logf("f32: \n%s", ml.Dump(ctx, t1f)) + + t2 := ctx.(*Context).FromFloatSlice(data1[:], 2880, 1, 32) + + t4 := t1.Mulmat(ctx, t2) + t4f := t1f.Mulmat(ctx, t2) + + d4 := ml.Dump(ctx, t4, ml.DumpWithPrecision(3)) + d4f := ml.Dump(ctx, t4f, ml.DumpWithPrecision(3)) + + r4 := t4.Floats() + r4f := t4f.Floats() + sim := cosineSimilarity(r4, r4f) + if sim < 0.99 { + t.Fatalf("failed similarity test: %f", sim) + } + t.Logf("similarity: %f", sim) + + if d4 != d4f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4) + } + // t.Logf("Mulmat results matched:\n%s", d4) + }) + + t.Run("exact/3x3", func(t *testing.T) { + r := rand.New(rand.NewSource(0)) + ctx := initContextOrSkip(t, b, useGPU) + const s10 = 64 + const s11 = 1 + const s12 = 2 + const s20 = s10 + const s21 = 1 + const s22 = 2 + + data := [s10 * s11 * s12]float32{} + for i := range data { + data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)] + } + // for i := range len(data) / 32 { + // vals := [32]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", data[i*32+j]) + // } + // t.Logf(" [%s]\n", strings.Join(vals[:], ", ")) + // } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + // for i := range len(mxData) / 17 { + // vals := [17]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2x", mxData[i*17+j]) + // } + // t.Logf(" %s\n", strings.Join(vals[:], ", ")) + // } + dtype := ml.DTypeMXFP4 + t1 := ctx.(*Context).FromBytes(dtype, mxData, s10, s11, s12) + t1f := ctx.(*Context).FromFloatSlice(data[:], s10, s11, s12) + + d2 := [s20 * s21 * s22]float32{} + for i := range d2 { + d2[i] = float32(r.Float32()) + } + t2 := ctx.(*Context).FromFloatSlice(d2[:], s20, s21, s22) + + t3f := t1f.Mulmat(ctx, t2) + t3 := t1.Mulmat(ctx, t2) + d3 := ml.Dump(ctx, t3) + d3f := ml.Dump(ctx, t3f) + if d3 != d3f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3) + } + }) + + t.Run("exact/2x2", func(t *testing.T) { + r := rand.New(rand.NewSource(0)) + ctx := 
initContextOrSkip(t, b, useGPU) + const s0 = 32 + const s1 = 64 + + data := [s0 * s1]float32{} + for i := range data { + data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)] + } + // for i := range 4 { + // vals := [32]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", data[i*32+j]) + // } + // t.Logf(" [%s]\n", strings.Join(vals[:], ", ")) + // } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + // for i := range len(mxData) / 17 { + // vals := [17]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2x", mxData[i*17+j]) + // } + // t.Logf(" %s\n", strings.Join(vals[:], ", ")) + // } + dtype := ml.DTypeMXFP4 + t1 := ctx.(*Context).FromBytes(dtype, mxData, s0, s1) + t1f := ctx.(*Context).FromFloatSlice(data[:], s0, s1) + + d2 := [s0 * s1]float32{} + for i := range d2 { + d2[i] = float32(r.Float32()) + } + t2 := ctx.(*Context).FromFloatSlice(d2[:], s0, s1) + + t3f := t1f.Mulmat(ctx, t2) + t3 := t1.Mulmat(ctx, t2) + d3 := ml.Dump(ctx, t3) + d3f := ml.Dump(ctx, t3f) + if d3 != d3f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3) + } + }) + t.Run("exact/2x1", func(t *testing.T) { + r := rand.New(rand.NewSource(0)) + ctx := initContextOrSkip(t, b, useGPU) + const s0 = 64 + const s1 = 4 + + data := [s0 * s1]float32{} + for i := range data { + data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)] + } + // for i := range len(data) / 32 { + // vals := [32]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", data[i*32+j]) + // } + // t.Logf(" t1[%s]\n", strings.Join(vals[:], ", ")) + // } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + // for i := range len(mxData) / 17 { + // vals := [17]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2x", mxData[i*17+j]) + // } + // t.Logf(" %s\n", strings.Join(vals[:], ", ")) + // } + dtype := ml.DTypeMXFP4 + t1 := ctx.(*Context).FromBytes(dtype, mxData, s0, s1) + t1f := ctx.(*Context).FromFloatSlice(data[:], s0, s1) + + d2 := [s0]float32{} + for i := range d2 { + d2[i] = float32(r.Float32()) + } + // for i := range len(d2) / 32 { + // vals := [32]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", d2[i*32+j]) + // } + // t.Logf(" t2[%s]\n", strings.Join(vals[:], ", ")) + // } + + t2 := ctx.(*Context).FromFloatSlice(d2[:], s0) + + t3f := t1f.Mulmat(ctx, t2) + t3 := t1.Mulmat(ctx, t2) + d3 := ml.Dump(ctx, t3, ml.DumpWithPrecision(3)) + d3f := ml.Dump(ctx, t3f, ml.DumpWithPrecision(3)) + if d3 != d3f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3) + } + }) + + t.Run("range/2d", func(t *testing.T) { + r := rand.New(rand.NewSource(0)) + ctx := initContextOrSkip(t, b, useGPU) + const s0 = 32 + const s1 = 4 + data := [s0 * s1]float32{} + inTotal := float32(0) + for i := range data { + data[i] = float32(i) + inTotal += float32(i) + } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + // Reconvert back to floats to remove the quantization fidelity loss for comparison + dataf := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data))) + dtype := ml.DTypeMXFP4 + t1 := ctx.(*Context).FromBytes(dtype, mxData, s0, s1) + t1f := ctx.(*Context).FromFloatSlice(dataf, s0, s1) + // for i := range len(data) / 32 { + // vals := [32]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", dataf[i*32+j]) + // } + // t.Logf(" t1[%s]\n", strings.Join(vals[:], ", ")) + // } + + d2 := [s0 * s1]float32{} + for i := range d2 { + // 
d2[i] = float32(i) + d2[i] = float32(r.Float32()) + } + // for i := range len(d2) / s0 { + // vals := [s0]string{} + // for j := range vals { + // vals[j] = fmt.Sprintf("%0.2f", d2[i*s0+j]) + // } + // t.Logf(" t2[%s]\n", strings.Join(vals[:], ", ")) + // } + + t2 := ctx.(*Context).FromFloatSlice(d2[:], s0, s1) + + // t.Log("calling Mulmat") + t3 := t1.Mulmat(ctx, t2) + t3f := t1f.Mulmat(ctx, t2) + d3 := ml.Dump(ctx, t3, ml.DumpWithPrecision(2)) + d3f := ml.Dump(ctx, t3f, ml.DumpWithPrecision(2)) + r3 := t3.Floats() + r3f := t3f.Floats() + sim := cosineSimilarity(r3, r3f) + if sim < 0.99 { + t.Logf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3) + t.Fatalf("failed similarity test: %f", sim) + } + t.Logf("similarity: %f", sim) + if d3 != d3f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3) + } + }) + + t.Run("range/3d", func(t *testing.T) { + ctx := initContextOrSkip(t, b, useGPU) + data := [32 * 4 * 2]float32{} + inTotal := float32(0) + for i := range data { + data[i] = float32(i) + inTotal += float32(i) + } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + dtype := ml.DTypeMXFP4 + // Reconvert back to floats to remove the quantization fidelity loss for comparison + dataf := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data))) + t1 := ctx.(*Context).FromBytes(dtype, mxData, 32, 4, 2) + t1f := ctx.(*Context).FromFloatSlice(dataf, 32, 4, 2) + + d2 := [32 * 4 * 2]float32{} + for i := range d2 { + d2[i] = 2.0 + } + t2 := ctx.(*Context).FromFloatSlice(d2[:], 32, 4, 2) + + // t.Log("calling Mulmat") + t3 := t1.Mulmat(ctx, t2) + t3f := t1f.Mulmat(ctx, t2) + d3 := ml.Dump(ctx, t3) + d3f := ml.Dump(ctx, t3f) + r3 := t3.Floats() + r3f := t3f.Floats() + sim := cosineSimilarity(r3, r3f) + if sim < 0.99 { + t.Logf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3) + t.Fatalf("failed similarity test: %f", sim) + } + t.Logf("similarity: %f", sim) + if d3 != d3f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3) + } + }) + }) + }) + } +} + +func TestMXFP4Simple(t *testing.T) { + b := setup(t) + + t.Run("fixed", func(t *testing.T) { + ctx := initContextOrSkip(t, b, false) + data := [32 * 2]float32{ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + dtype := ml.DTypeMXFP4 + // Reconvert back to floats to remove the quantization fidelity loss for comparison + dataf := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data))) + t1 := ctx.(*Context).FromBytes(dtype, mxData, 32, 2) + t1f := ctx.(*Context).FromFloatSlice(dataf, 32, 2) + + d2 := [32 * 2]float32{ + // 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + } + t2 := ctx.(*Context).FromFloatSlice(d2[:], 32, 2) + + t.Log("calling Mulmat") + t3f := t1f.Mulmat(ctx, t2) + t3 := t1.Mulmat(ctx, t2) + d3 := ml.Dump(ctx, t3) + d3f := ml.Dump(ctx, t3f) + if d3 != d3f { + t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3) + } + t.Logf("result (mxfp4): \n%s", d3) + }) +} + +func TestMXFP4Conversion(t *testing.T) { + t.Run("quantize/exact", func(t 
*testing.T) { + r := rand.New(rand.NewSource(0)) + + data := [32 * 4]float32{} + for i := range data { + data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)] + } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + newData := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data))) + + if len(data) != len(newData) { + t.Fatalf("length mismatch. started with %d but got %d", len(data), len(newData)) + } + for i := range data { + if data[i] != newData[i] { + t.Logf("started with: %v", data) + t.Logf("got : %v", newData) + t.Fatalf("mismatched data starting at offset %d started with %f but got %f", i, data[i], newData[i]) + } + } + }) + t.Run("quantize/arange", func(t *testing.T) { + data := [32 * 8]float32{} + for i := range data { + data[i] = float32(i) // / float32(6.0) + } + mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))}) + newData := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data))) + + if len(data) != len(newData) { + t.Fatalf("length mismatch. started with %d but got %d", len(data), len(newData)) + } + sim := cosineSimilarity(data[:], newData) + if sim < 0.99 { + t.Fatalf("failed similarity test: %f", sim) + } + t.Logf("similarity: %f", sim) + }) +} + +func dotProduct[V float32 | float64](v1, v2 []V) V { + var result V = 0 + for i := range v1 { + result += v1[i] * v2[i] + } + return result +} + +func magnitude[V float32 | float64](v []V) V { + var result V = 0 + for _, val := range v { + result += val * val + } + return V(math.Sqrt(float64(result))) +} + +func cosineSimilarity[V float32 | float64](v1, v2 []V) V { + return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2)) +} diff --git a/ml/backend/ggml/quantization.go b/ml/backend/ggml/quantization.go index bb31e455d..648ab74bb 100644 --- a/ml/backend/ggml/quantization.go +++ b/ml/backend/ggml/quantization.go @@ -44,6 +44,8 @@ func ConvertToF32(data []byte, dtype uint32, nelements uint64) []float32 { C.dequantize_row_q6_K((*C.block_q6_K)(unsafe.Pointer(&data[0])), (*C.float)(&f32s[0]), elems) case C.GGML_TYPE_BF16: C.ggml_bf16_to_fp32_row((*C.ggml_bf16_t)(unsafe.Pointer(&data[0])), (*C.float)(&f32s[0]), elems) + case C.GGML_TYPE_MXFP4: + C.dequantize_row_mxfp4((*C.block_mxfp4)(unsafe.Pointer(&data[0])), (*C.float)(&f32s[0]), elems) default: panic("unsupported quantization format") } diff --git a/ml/nn/linear.go b/ml/nn/linear.go index 3985dd6c8..5bcde84de 100644 --- a/ml/nn/linear.go +++ b/ml/nn/linear.go @@ -15,3 +15,26 @@ func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor { return t } + +type LinearBatch struct { + Weight ml.Tensor `gguf:"weight"` + Bias ml.Tensor `gguf:"bias"` +} + +func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor { + t = m.Weight.MulmatID(ctx, t, indices) + if m.Bias != nil { + var bias ml.Tensor + if len(indices.Shape()) > 1 { + // FIXME: Rows does not support 2D indices for a 2D input tensor so reshape indices to 1D. + bias = m.Bias.Rows(ctx, indices.Contiguous(ctx, indices.Dim(0)*indices.Dim(1))). + Duplicate(ctx). 
+ Reshape(ctx, m.Bias.Dim(0), indices.Dim(0), indices.Dim(1)) + } else { + bias = m.Bias.Rows(ctx, indices) + } + t = t.Add(ctx, bias) + } + + return t +} diff --git a/ml/nn/rope/rope.go b/ml/nn/rope/rope.go index b0c00a5b9..3b72d1cf9 100644 --- a/ml/nn/rope/rope.go +++ b/ml/nn/rope/rope.go @@ -4,9 +4,15 @@ import "github.com/ollama/ollama/ml" // Options contains optional parameters for RoPE function type Options struct { - OriginalContextLength int Type int Factors ml.Tensor + OriginalContextLength int + + // YaRN options + ExtrapolationFactor, + AttentionFactor, + BetaFast, + BetaSlow float32 } // WithOriginalContextLength sets a custom context length @@ -31,3 +37,15 @@ func WithFactors(factors ml.Tensor) func(*Options) { } } } + +func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) { + return func(opts *Options) { + opts.ExtrapolationFactor = extrapolationFactor + } +} + +func WithAttentionFactor(attentionFactor float32) func(*Options) { + return func(opts *Options) { + opts.AttentionFactor = attentionFactor + } +} diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go index 246d2ba3e..7ade497da 100644 --- a/model/bytepairencoding.go +++ b/model/bytepairencoding.go @@ -22,7 +22,7 @@ var _ TextProcessor = (*BytePairEncoding)(nil) func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { return BytePairEncoding{ - pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2), + pre: regexp2.MustCompile(pre, regexp2.None), vocab: vocab, } } diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go new file mode 100644 index 000000000..22b3e0794 --- /dev/null +++ b/model/models/gptoss/model.go @@ -0,0 +1,268 @@ +package gptoss + +import ( + "cmp" + "math" + "strings" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Transformer struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + TransformerBlocks []TransformerBlock `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Options +} + +// Forward implements model.Model. +func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + + one := ctx.Input().FromFloatSlice([]float32{1}, 1) + for i, block := range m.TransformerBlocks { + m.Cache.SetLayer(i) + if c, ok := m.Cache.(*kvcache.WrapperCache); ok { + // Even layers are sliding window attention. 
+ c.SetLayerType(i % 2) + } + + var outputs ml.Tensor + if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 { + outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + } + + hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func (m *Transformer) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil +} + +type Options struct { + hiddenSize, + numHeads, + numKVHeads, + keyLength, + valueLength, + numExperts, + numExpertsUsed, + originalContextLength int + + eps, + ropeBase, + ropeScale float32 +} + +func (o Options) RoPEOptions() []func(*rope.Options) { + return []func(*rope.Options){ + rope.WithTypeNeoX(), + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithExtrapolationFactor(1.), + // NOTE: ggml sets this implicitly so there's no need to set it here + // rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0), + } +} + +func (o Options) headDim() int { + return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) +} + +type TransformerBlock struct { + Attention *AttentionBlock + MLP *MLPBlock +} + +func (d *TransformerBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs, one ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts) + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + } + + hiddenStates = d.MLP.Forward(ctx, hiddenStates, one, opts) + return hiddenStates +} + +type AttentionBlock struct { + Norm *nn.RMSNorm `gguf:"attn_norm"` + QKV *nn.Linear `gguf:"attn_qkv"` + Output *nn.Linear `gguf:"attn_out"` + Sinks ml.Tensor `gguf:"attn_sinks"` +} + +func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + + residual := hiddenStates + hiddenStates = attn.Norm.Forward(ctx, hiddenStates, opts.eps) + + qkv := attn.QKV.Forward(ctx, hiddenStates) + + // query = qkv[..., : num_attention_heads * head_dim].reshape(batch_size, num_attention_heads, head_dim) + query := qkv.View(ctx, + 0, + opts.headDim(), qkv.Stride(0)*opts.headDim(), + opts.numHeads, qkv.Stride(1), + batchSize, + ) + query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) + + // key = qkv[..., num_attention_heads * head_dim:(num_attention_heads + num_key_value_heads) * head_dim].reshape(batch_size, num_key_value_heads, head_dim) + key := qkv.View(ctx, + qkv.Stride(0)*opts.headDim()*opts.numHeads, + opts.headDim(), qkv.Stride(0)*opts.headDim(), + opts.numKVHeads, qkv.Stride(1), + batchSize, + ) + key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) 
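+ // qkv is a fused projection laid out as [query | key | value] along dim 0 (numHeads*headDim,
+ // then numKVHeads*headDim, then numKVHeads*headDim), which is what the byte offsets of the three
+ // Views here assume; RoPE is applied to query and key only, never to value.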
+ + // value = qkv[..., (num_attention_heads + num_key_value_heads) * head_dim:].reshape(batch_size, num_key_value_heads, head_dim) + value := qkv.View(ctx, + qkv.Stride(0)*opts.headDim()*(opts.numHeads+opts.numKVHeads), + opts.headDim(), qkv.Stride(0)*opts.headDim(), + opts.numKVHeads, qkv.Stride(1), + batchSize, + ) + + cache.Put(ctx, key, value) + key, value, mask := cache.Get(ctx) + + query = query.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + + scores := key.MulmatFullPrec(ctx, query) + scores = scores.Scale(ctx, 1./math.Sqrt(float64(opts.headDim()))) + scores = scores.Add(ctx, mask) + + scores = scores.Concat(ctx, attn.Sinks.Reshape(ctx, 1, 1, opts.numHeads, 1).Repeat(ctx, 1, batchSize), 0) + scores = scores.Softmax(ctx) + scores = scores.Pad(ctx, -1, 0, 0, 0) + + attention := value.Mulmat(ctx, scores) + attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) + + return attn.Output.Forward(ctx, attention).Add(ctx, residual) +} + +type MLPBlock struct { + Norm *nn.RMSNorm `gguf:"ffn_norm"` + Router *nn.Linear `gguf:"ffn_gate_inp"` + GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` +} + +func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *Options) ml.Tensor { + hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2) + + residual := hiddenStates + hiddenStates = mlp.Norm.Forward(ctx, hiddenStates, opts.eps) + + hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize) + routingWeights := mlp.Router.Forward(ctx, hiddenStates) + + selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, sequenceLength*batchSize).Rows(ctx, selectedExperts) + routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, sequenceLength*batchSize).Softmax(ctx) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, sequenceLength*batchSize) + + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + + hiddenStates = mlp.GateUp.Forward(ctx, hiddenStates, selectedExperts) + hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2)) + + dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)} + + glu := hiddenStates.View(ctx, 0, dimStride...) + glu = glu.Contiguous(ctx) + glu = glu.Clamp(ctx, float32(math.Inf(-1)), 7.0) + glu = glu.QuickGELU(ctx) + + linear := hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...) 
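+ // glu and linear are the interleaved halves of the gate_up projection; together these lines
+ // compute QuickGELU(min(glu, 7)) * (clamp(linear, -7, 7) + 1), with the +1 supplied by the
+ // broadcast `one` tensor - the gated activation used by this MLP block.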
+ linear = linear.Clamp(ctx, -7.0, 7.0) + + hiddenStates = glu.Mul(ctx, linear.Add(ctx, one)) + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)) + + experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) + experts = experts.Mul(ctx, routingWeights) + + nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + + return nextStates.Add(ctx, residual) +} + +func New(c fs.Config) (model.Model, error) { + m := Transformer{ + TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")), + BytePairEncoding: model.NewBytePairEncoding( + c.String("tokenizer.ggml.pretokenizer", + strings.Join([]string{ + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `\p{N}{1,3}`, + ` ?[^\s\p{L}\p{N}]+[\r\n/]*`, + `\s*[\r\n]+`, + `\s+(?!\S)`, + `\s+`, + }, "|"), + ), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + ), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.scaling.factor", 1.), + originalContextLength: int(c.Uint("rope.scaling.original_context_length")), + }, + } + + m.Cache = kvcache.NewWrapperCache( + kvcache.NewSWAMemCache(int32(c.Uint("attention.sliding_window")), 4096, m.Shift), + kvcache.NewCausalCache(m.Shift), + ) + m.Cache.SetConfig(ml.CacheConfig{CachePadding: 32, PermutedV: true}) + return &m, nil +} + +func init() { + model.Register("gptoss", New) +} diff --git a/model/models/models.go b/model/models/models.go index 8752878e2..c880a4720 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -4,6 +4,7 @@ import ( _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3n" + _ "github.com/ollama/ollama/model/models/gptoss" _ "github.com/ollama/ollama/model/models/llama" _ "github.com/ollama/ollama/model/models/llama4" _ "github.com/ollama/ollama/model/models/mistral3" diff --git a/openai/openai.go b/openai/openai.go index 35b8b9a01..d065de8f1 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -36,6 +36,7 @@ type ErrorResponse struct { type Message struct { Role string `json:"role"` Content any `json:"content"` + Reasoning string `json:"reasoning,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` } @@ -81,6 +82,10 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage"` 
} +type Reasoning struct { + Effort *string `json:"effort,omitempty"` +} + type ChatCompletionRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` @@ -95,6 +100,7 @@ type ChatCompletionRequest struct { TopP *float64 `json:"top_p"` ResponseFormat *ResponseFormat `json:"response_format"` Tools []api.Tool `json:"tools"` + Reasoning *Reasoning `json:"reasoning,omitempty"` } type ChatCompletion struct { @@ -253,7 +259,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { SystemFingerprint: "fp_ollama", Choices: []Choice{{ Index: 0, - Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls}, + Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking}, FinishReason: func(reason string) *string { if len(toolCalls) > 0 { reason = "tool_calls" @@ -278,10 +284,10 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu SystemFingerprint: "fp_ollama", Choices: []ChunkChoice{{ Index: 0, - Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls}, + Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking}, FinishReason: func(reason string) *string { if len(reason) > 0 { - if toolCallSent { + if toolCallSent || len(toolCalls) > 0 { return &finishReasonToolCalls } return &reason @@ -397,7 +403,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { for _, msg := range r.Messages { switch content := msg.Content.(type) { case string: - messages = append(messages, api.Message{Role: msg.Role, Content: content}) + messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning}) case []any: for _, c := range content { data, ok := c.(map[string]any) @@ -508,6 +514,10 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { options["top_p"] = 1.0 } + if r.Reasoning != nil { + options["reasoning"] = *r.Reasoning.Effort + } + var format json.RawMessage if r.ResponseFormat != nil { switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) { @@ -521,6 +531,13 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { } } + var think *api.ThinkValue + if r.Reasoning != nil { + think = &api.ThinkValue{ + Value: *r.Reasoning.Effort, + } + } + return &api.ChatRequest{ Model: r.Model, Messages: messages, @@ -528,6 +545,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { Options: options, Stream: &r.Stream, Tools: r.Tools, + Think: think, }, nil } diff --git a/server/harmonyparser.go b/server/harmonyparser.go new file mode 100644 index 000000000..fd6c64e73 --- /dev/null +++ b/server/harmonyparser.go @@ -0,0 +1,379 @@ +package server + +import ( + "context" + "log/slog" + "strings" + "unicode" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/logutil" +) + +type harmonyParserState int + +const ( + harmonyParserState_LookingForMessageStart harmonyParserState = iota + harmonyParserState_ParsingHeader + harmonyParserState_ParsingContent +) + +func shouldUseHarmony(model Model) bool { + if model.Config.ModelFamily == "gptoss" { + // heuristic to check whether the template expects to be parsed via harmony: + // search for harmony tags that are nearly always used + if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") { + return true + } + } + + return false +} + +func (s harmonyParserState) String() string { + 
switch s { + // we're looking for the message start tag + case harmonyParserState_LookingForMessageStart: + return "LookingForMessageStart" + case harmonyParserState_ParsingHeader: + return "ParsingHeader" + case harmonyParserState_ParsingContent: + return "ParsingContent" + default: + return "Unknown" + } +} + +type HarmonyParser struct { + state harmonyParserState + MessageStartTag string + MessageEndTag string + HeaderEndTag string + acc strings.Builder + lifetimeAcc strings.Builder +} + +type HarmonyEvent interface { + isHarmonyEvent() +} + +type HarmonyEventMessageStart struct{} + +func (HarmonyEventMessageStart) isHarmonyEvent() {} + +type HarmonyEventHeaderComplete struct { + Header HarmonyHeader +} + +func (HarmonyEventHeaderComplete) isHarmonyEvent() {} + +type HarmonyEventContentEmitted struct { + Content string +} + +func (HarmonyEventContentEmitted) isHarmonyEvent() {} + +type HarmonyEventMessageEnd struct{} + +func (HarmonyEventMessageEnd) isHarmonyEvent() {} + +type HarmonyHeader struct { + Role string + Channel string + Recipient string +} + +func (s *HarmonyParser) AddImplicitStart() { + s.acc.WriteString("<|start|>assistant") +} + +func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) { + if lastMessage != nil && lastMessage.Role == "assistant" { + // handle prefilling conditions + if lastMessage.Content != "" { + s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>") + return + } else if lastMessage.Thinking != "" { + s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>") + return + } + } + s.AddImplicitStart() +} + +func (s *HarmonyParser) AddContent(content string) []HarmonyEvent { + s.lifetimeAcc.WriteString(content) + s.acc.WriteString(content) + + var events []HarmonyEvent + + keepLooping := true + // we loop because we might pass through multiple parsing states in a single + // call to addContent, and we want to make sure callers don't have to wait for + // data that's already unambiguous + for keepLooping { + var newEvents []HarmonyEvent + newEvents, keepLooping = eat(s) + events = append(events, newEvents...) + } + + return events +} + +// the additional bool return is true iff we should continue eating +func eat(s *HarmonyParser) ([]HarmonyEvent, bool) { + switch s.state { + case harmonyParserState_LookingForMessageStart: + // does the acc contain the message start tag? 
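+ // partial input is simply buffered: nothing is emitted from this state until the full start tag
+ // has been seen, so a chunk like "<|sta" produces no events yet (see the streaming tests).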
+ if strings.Contains(s.acc.String(), s.MessageStartTag) { + // split the acc into the message start tag and the rest + split := strings.SplitN(s.acc.String(), s.MessageStartTag, 2) + before := split[0] + if before != "" { + slog.Warn("harmony parser: found message start tag in the middle of the content", "content", s.acc.String()) + } + after := split[1] + s.acc.Reset() + s.acc.WriteString(after) + s.state = harmonyParserState_ParsingHeader + return []HarmonyEvent{HarmonyEventMessageStart{}}, true + } + + // no match, so we keep accumulating + return nil, false + case harmonyParserState_ParsingHeader: + if strings.Contains(s.acc.String(), s.HeaderEndTag) { + split := strings.SplitN(s.acc.String(), s.HeaderEndTag, 2) + header := split[0] + after := split[1] + s.acc.Reset() + s.acc.WriteString(after) + s.state = harmonyParserState_ParsingContent + return []HarmonyEvent{HarmonyEventHeaderComplete{Header: s.parseHeader(header)}}, true + } + return nil, false + case harmonyParserState_ParsingContent: + if strings.Contains(s.acc.String(), s.MessageEndTag) { + // if we already have the message end tag, we can emit the content up to it + split := strings.SplitN(s.acc.String(), s.MessageEndTag, 2) + content := split[0] + after := split[1] + s.acc.Reset() + s.acc.WriteString(after) + s.state = harmonyParserState_LookingForMessageStart + events := []HarmonyEvent{} + if content != "" { + events = append(events, HarmonyEventContentEmitted{Content: content}) + } + events = append(events, HarmonyEventMessageEnd{}) + return events, true + } else if overlapLen := overlap(s.acc.String(), s.MessageEndTag); overlapLen > 0 { + // if our suffix contains the start of the message end tag, we can emit + // the content up to the start of the message end tag + content := s.acc.String()[:len(s.acc.String())-overlapLen] + remaining := s.acc.String()[len(s.acc.String())-overlapLen:] + s.acc.Reset() + s.acc.WriteString(remaining) + // emit the content we know isn't part of the message end tag, and keep + // accumulating to disambiguate the rest + if content == "" { + return nil, false + } + return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false + } else { + // no end tag, so it's still normal content that we can immediately emit + content := s.acc.String() + if content == "" { + return nil, false + } + s.acc.Reset() + return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false + } + } + + return nil, false +} + +func (s *HarmonyParser) parseHeader(raw string) HarmonyHeader { + harmonyHeader := HarmonyHeader{} + + // if `<|constrain|>` is present, ensure it has a space before it so it gets + // parsed as a separate token, even if the model didn't include the space + if strings.Contains(raw, "<|constrain|>") { + raw = strings.Replace(raw, "<|constrain|>", " <|constrain|>", 1) + raw = strings.TrimSpace(raw) + } + + // look for the optional channel tag, which is `<|channel|>` followed by the + // channel name, all without any whitespace + channelIndex := strings.Index(raw, "<|channel|>") + if channelIndex != -1 { + before := raw[:channelIndex] + after := raw[channelIndex+len("<|channel|>"):] + // the channel name is `after` all the way up to the first (if any) whitespace character + idx := strings.IndexFunc(after, func(r rune) bool { + return unicode.IsSpace(r) + }) + if idx == -1 { + idx = len(after) + } + harmonyHeader.Channel = after[:idx] + after = after[idx:] + // now we remove the channel tag from the raw string to further process + raw = before + after + raw = 
strings.TrimSpace(raw) + } + + // split the header into whitespace-separated tokens + tokens := strings.Fields(raw) + + // the first token is treated as the role + if len(tokens) == 0 { + slog.Error("harmony parser: missing role in header", "header", raw) + return harmonyHeader + } + role := tokens[0] + tokens = tokens[1:] + // special case: if role starts with to= then it's a tool call + if strings.HasPrefix(role, "to=") { + harmonyHeader.Recipient = role[3:] + harmonyHeader.Role = "tool" + } else { + harmonyHeader.Role = role + } + + // the recipient (if any) can be specified before or after the channel tag, so + // we check it at the end once we've already parsed the channel and role + if harmonyHeader.Recipient == "" && len(tokens) > 0 && strings.HasPrefix(tokens[0], "to=") { + harmonyHeader.Recipient = tokens[0][3:] + } + + return harmonyHeader +} + +// longest overlap between suffix of s and prefix of delim +func overlap(s, delim string) int { + max := min(len(delim), len(s)) + for i := max; i > 0; i-- { + if strings.HasSuffix(s, delim[:i]) { + return i + } + } + return 0 +} + +// harmonyMessageState represents the current state of message processing +type harmonyMessageState int + +const ( + harmonyMessageState_Normal harmonyMessageState = iota + harmonyMessageState_Thinking + harmonyMessageState_ToolCalling +) + +// HarmonyMessageHandler processes harmony events and accumulates content appropriately. +// This is a higher level interface that maps harmony concepts into ollama concepts +type HarmonyMessageHandler struct { + state harmonyMessageState + harmonyParser *HarmonyParser +} + +// NewHarmonyMessageHandler creates a new message handler +func NewHarmonyMessageHandler() *HarmonyMessageHandler { + return &HarmonyMessageHandler{ + state: harmonyMessageState_Normal, + harmonyParser: &HarmonyParser{ + MessageStartTag: "<|start|>", + MessageEndTag: "<|end|>", + HeaderEndTag: "<|message|>", + }, + } +} + +// AddContent processes the content and returns the content, thinking, and tool content. 
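+// ("final" text is treated as regular content, "analysis" as thinking, and analysis/commentary messages addressed to a to= recipient as tool content.)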
+// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser +func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) { + contentSb := strings.Builder{} + thinkingSb := strings.Builder{} + toolContentSb := strings.Builder{} + + events := h.harmonyParser.AddContent(content) + for _, event := range events { + switch event := event.(type) { + case HarmonyEventHeaderComplete: + slog.Log(context.TODO(), logutil.LevelTrace, "harmony event header complete", "header", event.Header) + switch event.Header.Channel { + case "analysis": + if event.Header.Recipient != "" { + h.state = harmonyMessageState_ToolCalling + // event.Header.Recipient is the tool name, something like + // "browser.search" for a built-in, or "functions.calc" for a + // custom one + toolParser.SetToolName(event.Header.Recipient) + } else { + h.state = harmonyMessageState_Thinking + } + case "commentary": + if event.Header.Recipient != "" { + h.state = harmonyMessageState_ToolCalling + toolParser.SetToolName(event.Header.Recipient) + } else { + h.state = harmonyMessageState_Normal + } + case "final": + h.state = harmonyMessageState_Normal + } + case HarmonyEventContentEmitted: + slog.Log(context.TODO(), logutil.LevelTrace, "harmony event content", "content", event.Content, "state", h.state) + if h.state == harmonyMessageState_Normal { + contentSb.WriteString(event.Content) + } else if h.state == harmonyMessageState_Thinking { + thinkingSb.WriteString(event.Content) + } else if h.state == harmonyMessageState_ToolCalling { + toolContentSb.WriteString(event.Content) + } + case HarmonyEventMessageEnd: + h.state = harmonyMessageState_Normal + } + } + return contentSb.String(), thinkingSb.String(), toolContentSb.String() +} + +func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator { + return &HarmonyToolCallAccumulator{ + state: harmonyToolCallState_Normal, + currentToolName: nil, + } +} + +type harmonyToolCallState int + +const ( + harmonyToolCallState_Normal harmonyToolCallState = iota + harmonyToolCallState_ToolCalling +) + +type HarmonyToolCallAccumulator struct { + state harmonyToolCallState + acc strings.Builder + currentToolName *string +} + +func (a *HarmonyToolCallAccumulator) SetToolName(toolName string) { + a.currentToolName = &toolName +} + +func (a *HarmonyToolCallAccumulator) Add(content string) { + a.acc.WriteString(content) +} + +func (a *HarmonyToolCallAccumulator) Drain() (*string, string) { + str := a.acc.String() + a.state = harmonyToolCallState_Normal + a.acc.Reset() + return a.currentToolName, str +} + +func (a *HarmonyToolCallAccumulator) Content() string { + return a.acc.String() +} diff --git a/server/harmonyparser_test.go b/server/harmonyparser_test.go new file mode 100644 index 000000000..cd1743e1c --- /dev/null +++ b/server/harmonyparser_test.go @@ -0,0 +1,469 @@ +package server + +import ( + "fmt" + "reflect" + "testing" +) + +func TestHeaderParsing(t *testing.T) { + tests := []struct { + in, wantRole, wantChannel, wantRecipient string + }{ + { + in: "assistant<|channel|>analysis", + wantRole: "assistant", + wantChannel: "analysis", + wantRecipient: "", + }, + { + in: "assistant<|channel|>analysis to=functions.get_weather", + wantRole: "assistant", + wantChannel: "analysis", + wantRecipient: "functions.get_weather", + }, + { + in: "assistant to=functions.get_weather<|channel|>analysis", + wantRole: "assistant", + wantChannel: "analysis", + wantRecipient: 
"functions.get_weather", + }, + // special case where the role is replaced by the recipient (matches reference code) + { + in: "to=functions.get_weather<|channel|>analysis", + wantRole: "tool", + wantChannel: "analysis", + wantRecipient: "functions.get_weather", + }, + // extra token after the recipient is ignored + { + in: "assistant to=functions.get_weather abc<|channel|>analysis", + wantRole: "assistant", + wantChannel: "analysis", + wantRecipient: "functions.get_weather", + }, + // with constrain tag, recipient after channel tag + { + in: "assistant<|channel|>commentary to=functions.get_weather <|constrain|>json", + wantRole: "assistant", + wantChannel: "commentary", + wantRecipient: "functions.get_weather", + }, + // with constrain tag, recipient before channel tag + { + in: "assistant to=functions.get_weather<|channel|>commentary <|constrain|>json", + wantRole: "assistant", + wantChannel: "commentary", + wantRecipient: "functions.get_weather", + }, + // constrain tag without space + { + in: "assistant<|channel|>commentary to=functions.get_weather<|constrain|>json", + wantRole: "assistant", + wantChannel: "commentary", + wantRecipient: "functions.get_weather", + }, + // constrain tag without space, different order + { + in: "assistant to=functions.get_weather<|channel|>commentary<|constrain|>json", + wantRole: "assistant", + wantChannel: "commentary", + wantRecipient: "functions.get_weather", + }, + } + for i, tt := range tests { + parser := HarmonyParser{ + MessageStartTag: "<|start|>", + MessageEndTag: "<|end|>", + HeaderEndTag: "<|message|>", + } + header := parser.parseHeader(tt.in) + + if header.Role != tt.wantRole { + t.Errorf("case %d: got role \"%s\", want \"%s\"", i, header.Role, tt.wantRole) + } + if header.Channel != tt.wantChannel { + t.Errorf("case %d: got channel \"%s\", want \"%s\"", i, header.Channel, tt.wantChannel) + } + if header.Recipient != tt.wantRecipient { + t.Errorf("case %d: got recipient \"%s\", want \"%s\"", i, header.Recipient, tt.wantRecipient) + } + } +} + +func TestHarmonyParserHeaderEvent(t *testing.T) { + tests := []struct { + in, wantRole, wantChannel, wantRecipient string + implicitStart bool + }{ + { + in: "<|start|>user<|message|>What is 2 + 2?<|end|>", + wantRole: "user", + wantChannel: "", + wantRecipient: "", + }, + { + in: "<|start|>assistant<|channel|>analysis<|message|>What is 2 + 2?<|end|>", + wantRole: "assistant", + wantChannel: "analysis", + wantRecipient: "", + }, + { + in: "<|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{\"location\":\"San Francisco\"}<|call|><|start|>functions.get_weather to=assistant<|message|>{\"sunny\": true, \"temperature\": 20}<|end|>", + wantRole: "assistant", + wantChannel: "commentary", + wantRecipient: "functions.get_weather", + }, + { + in: "<|channel|>analysis<|message|>User asks weather in SF. We need location. 
Use get_current_weather with location \"San Francisco, CA\".<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\"location\":\"San Francisco, CA\"}<|call|>", + wantRole: "assistant", + wantChannel: "analysis", + wantRecipient: "", + implicitStart: true, + }, + } + for i, tt := range tests { + parser := HarmonyParser{ + MessageStartTag: "<|start|>", + MessageEndTag: "<|end|>", + HeaderEndTag: "<|message|>", + } + if tt.implicitStart { + parser.AddImplicitStart() + } + gotEvents := parser.AddContent(tt.in) + if len(gotEvents) == 0 { + t.Errorf("case %d: got no events, want at least one", i) + } + + var firstHeaderEvent *HarmonyEventHeaderComplete + // print events + for _, event := range gotEvents { + fmt.Printf("event: %+v\n", event) + } + for _, event := range gotEvents { + if event, ok := event.(HarmonyEventHeaderComplete); ok { + firstHeaderEvent = &event + break + } + } + + if firstHeaderEvent == nil { + t.Errorf("case %d: got no header complete event, want one", i) + continue + } + gotHeader := firstHeaderEvent.Header + if gotHeader.Role != tt.wantRole || gotHeader.Channel != tt.wantChannel || gotHeader.Recipient != tt.wantRecipient { + t.Errorf("case %d: got header %+v, want role=%s channel=%s recipient=%s", i, gotHeader, tt.wantRole, tt.wantChannel, tt.wantRecipient) + } + } +} + +func TestHarmonyParserNonStreaming(t *testing.T) { + tests := []struct { + in string + implicitStart bool + wantEvents []HarmonyEvent + }{ + { + in: "<|start|>user<|message|>What is 2 + 2?<|end|>", + wantEvents: []HarmonyEvent{ + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}}, + HarmonyEventContentEmitted{Content: "What is 2 + 2?"}, + HarmonyEventMessageEnd{}, + }, + }, + { + in: "<|start|>assistant<|channel|>analysis<|message|>The answer is 4<|end|>", + wantEvents: []HarmonyEvent{ + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}, + HarmonyEventContentEmitted{Content: "The answer is 4"}, + HarmonyEventMessageEnd{}, + }, + }, + { + in: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>Computing...<|end|>", + wantEvents: []HarmonyEvent{ + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}}, + HarmonyEventContentEmitted{Content: "Computing..."}, + HarmonyEventMessageEnd{}, + }, + }, + { + in: "<|start|>user<|message|><|end|>", + wantEvents: []HarmonyEvent{ + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}}, + HarmonyEventMessageEnd{}, + }, + }, + { + in: "<|start|>user<|message|>Hello<|end|><|start|>assistant<|message|>Hi!<|end|>", + wantEvents: []HarmonyEvent{ + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}}, + HarmonyEventContentEmitted{Content: "Hello"}, + HarmonyEventMessageEnd{}, + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}}, + HarmonyEventContentEmitted{Content: "Hi!"}, + HarmonyEventMessageEnd{}, + }, + }, + { + in: "<|channel|>analysis<|message|>Thinking about the request<|end|>", + implicitStart: true, + wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}, HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: 
"analysis", Recipient: ""}}, HarmonyEventContentEmitted{Content: "Thinking about the request"}, HarmonyEventMessageEnd{}}, + }, + } + for i, tt := range tests { + parser := HarmonyParser{ + MessageStartTag: "<|start|>", + MessageEndTag: "<|end|>", + HeaderEndTag: "<|message|>", + } + if tt.implicitStart { + parser.AddImplicitStart() + } + gotEvents := parser.AddContent(tt.in) + if !reflect.DeepEqual(gotEvents, tt.wantEvents) { + t.Errorf("case %d: got events %#v, want %#v", i, gotEvents, tt.wantEvents) + } + } +} + +func TestHarmonyParserStreaming(t *testing.T) { + type step struct { + input string + wantEvents []HarmonyEvent + } + + cases := []struct { + desc string + implicitStart bool + steps []step + }{ + { + desc: "simple message streamed character by character", + steps: []step{ + { + input: "<", + wantEvents: nil, + }, + { + input: "|", + wantEvents: nil, + }, + { + input: "start|>u", + wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}}, + }, + { + input: "ser<|mess", + wantEvents: nil, + }, + { + input: "age|>Hi", + wantEvents: []HarmonyEvent{ + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}}, + HarmonyEventContentEmitted{Content: "Hi"}, + }, + }, + { + input: " there", + wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: " there"}}, + }, + { + input: "<|e", + wantEvents: nil, + }, + { + input: "nd|>", + wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}}, + }, + }, + }, + { + desc: "message with channel streamed", + steps: []step{ + { + input: "<|start|>assistant", + wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}}, + }, + { + input: "<|chan", + wantEvents: nil, + }, + { + input: "nel|>analysis", + wantEvents: nil, + }, + { + input: "<|message|>", + wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}}, + }, + { + input: "Thinking", + wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Thinking"}}, + }, + { + input: "...", + wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "..."}}, + }, + { + input: "<|end|>", + wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}}, + }, + }, + }, + { + desc: "message with channel and recipient", + steps: []step{ + { + input: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>", + wantEvents: []HarmonyEvent{ + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}}, + }, + }, + { + input: "{\"x\": 5}", + wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}}, + }, + { + input: "<|end|>", + wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}}, + }, + }, + }, + { + desc: "message with channel and recipient (receipient before channel)", + steps: []step{ + { + input: "<|start|>assistant to=functions.calc<|channel|>commentary<|message|>", + wantEvents: []HarmonyEvent{ + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}}, + }, + }, + { + input: "{\"x\": 5}", + wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}}, + }, + { + input: "<|end|>", + wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}}, + }, + }, + }, + { + desc: "implicit start with channel", + implicitStart: true, + steps: []step{ + { + input: "<|channel|>thinking", + wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}}, + }, + { + input: "<|message|>", + 
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "thinking", Recipient: ""}}}, + }, + { + input: "Processing request", + wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Processing request"}}, + }, + { + input: "<|end|>", + wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}}, + }, + }, + }, + { + desc: "multiple messages streamed", + steps: []step{ + { + input: "<|start|>user<|message|>Hello<|end|>", + wantEvents: []HarmonyEvent{ + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}}, + HarmonyEventContentEmitted{Content: "Hello"}, + HarmonyEventMessageEnd{}, + }, + }, + { + input: "<|start|>", + wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}}, + }, + { + input: "assistant<|message|>", + wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}}}, + }, + { + input: "Hi!", + wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Hi!"}}, + }, + { + input: "<|end|>", + wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}}, + }, + }, + }, + { + desc: "empty message", + steps: []step{ + { + input: "<|start|>system<|message|><|end|>", + wantEvents: []HarmonyEvent{ + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "system", Channel: "", Recipient: ""}}, + HarmonyEventMessageEnd{}, + }, + }, + }, + }, + { + desc: "partial tag that looks like end but isn't", + steps: []step{ + { + input: "<|start|>user<|message|>test<|e", + wantEvents: []HarmonyEvent{ + HarmonyEventMessageStart{}, + HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}}, + HarmonyEventContentEmitted{Content: "test"}, + }, + }, + { + input: "xample|>more", + wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "<|example|>more"}}, + }, + { + input: "<|end|>", + wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}}, + }, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + parser := HarmonyParser{ + MessageStartTag: "<|start|>", + MessageEndTag: "<|end|>", + HeaderEndTag: "<|message|>", + } + if tc.implicitStart { + parser.AddImplicitStart() + } + + for i, step := range tc.steps { + gotEvents := parser.AddContent(step.input) + if !reflect.DeepEqual(gotEvents, step.wantEvents) { + t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents) + } + } + }) + } +} diff --git a/server/images.go b/server/images.go index 38505cc51..0c16dd435 100644 --- a/server/images.go +++ b/server/images.go @@ -111,7 +111,8 @@ func (m *Model) Capabilities() []model.Capability { // Check for thinking capability openingTag, closingTag := thinking.InferTags(m.Template.Template) - if openingTag != "" && closingTag != "" { + hasTags := openingTag != "" && closingTag != "" + if hasTags || m.Config.ModelFamily == "gptoss" { capabilities = append(capabilities, model.CapabilityThinking) } diff --git a/server/prompt.go b/server/prompt.go index f8c895d71..5d6c3e27c 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -19,7 +19,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error) // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. 
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // latest message and 2) system messages -func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) { +func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) { var system []api.Message // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent @@ -42,11 +42,13 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. } thinkVal := false + thinkLevel := "" if think != nil { - thinkVal = *think + thinkVal = think.AsBool() + thinkLevel = think.AsString() } var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil { + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { return "", nil, err } @@ -101,10 +103,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. // truncate any messages that do not fit into the context window var b bytes.Buffer thinkVal := false + thinkLevel := "" if think != nil { - thinkVal = *think + thinkVal = think.AsBool() + thinkLevel = think.AsString() } - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil { + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { return "", nil, err } diff --git a/server/prompt_test.go b/server/prompt_test.go index 0043b9a47..659e64084 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -209,7 +209,7 @@ func TestChatPrompt(t *testing.T) { model := tt.model opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} think := false - prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think) + prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}) if tt.error == nil && err != nil { t.Fatal(err) } else if tt.error != nil && err != tt.error { diff --git a/server/routes.go b/server/routes.go index 40348e737..991e92003 100644 --- a/server/routes.go +++ b/server/routes.go @@ -112,6 +112,11 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C return nil, nil, nil, err } + // This model requires a minimum context to function effectively + if slices.Contains(model.Config.ModelFamilies, "gptoss") { + opts.NumCtx = max(opts.NumCtx, 8192) + } + runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive) var runner *runnerRef select { @@ -182,11 +187,26 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + useHarmony := shouldUseHarmony(*m) && !req.Raw + var harmonyMessageHandler *HarmonyMessageHandler + var harmonyToolParser *HarmonyToolCallAccumulator + if useHarmony { + harmonyMessageHandler = NewHarmonyMessageHandler() + harmonyMessageHandler.harmonyParser.AddImplicitStart() 
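+ // The completion stream begins at the header (e.g. "<|channel|>analysis<|message|>")
+ // rather than with an explicit "<|start|>" tag, so the parser is seeded with an
+ // implicit message start; the role then defaults to "assistant" when the header is
+ // parsed (see the implicitStart cases in the harmony parser tests above).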
+ harmonyToolParser = harmonyMessageHandler.CreateToolParser() + } + + // Validate Think value: string values currently only allowed for gptoss models + if req.Think != nil && req.Think.IsString() && !useHarmony { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.AsString())}) + return + } + caps := []model.Capability{model.CapabilityCompletion} if req.Suffix != "" { caps = append(caps, model.CapabilityInsert) } - if req.Think != nil && *req.Think { + if req.Think != nil && req.Think.AsBool() { caps = append(caps, model.CapabilityThinking) // TODO(drifkin): consider adding a warning if it's false and the model // doesn't support thinking. It's not strictly required, but it can be a @@ -261,7 +281,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) } - values.Think = req.Think != nil && *req.Think + values.Think = req.Think != nil && req.Think.AsBool() + values.ThinkLevel = "" + if req.Think != nil { + values.ThinkLevel = req.Think.AsString() + } values.IsThinkSet = req.Think != nil var b bytes.Buffer @@ -284,11 +308,13 @@ func (s *Server) GenerateHandler(c *gin.Context) { } var thinkingState *thinking.Parser - openingTag, closingTag := thinking.InferTags(m.Template.Template) - if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" { - thinkingState = &thinking.Parser{ - OpeningTag: openingTag, - ClosingTag: closingTag, + if !useHarmony { + openingTag, closingTag := thinking.InferTags(m.Template.Template) + if req.Think != nil && req.Think.AsBool() && openingTag != "" && closingTag != "" { + thinkingState = &thinking.Parser{ + OpeningTag: openingTag, + ClosingTag: closingTag, + } } } @@ -316,7 +342,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { }, } - if thinkingState != nil { + if useHarmony { + content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser) + res.Response = content + res.Thinking = thinking + harmonyToolParser.Add(toolContent) + } else if thinkingState != nil { thinking, content := thinkingState.AddContent(cr.Content) res.Thinking = thinking res.Response = content @@ -327,6 +358,25 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if cr.Done { + if useHarmony { + toolName, toolContent := harmonyToolParser.Drain() + if toolName != nil { + *toolName = strings.TrimPrefix(*toolName, "functions.") + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(toolContent), &args); err != nil { + ch <- gin.H{"error parsing tool call": err.Error()} + return + } + + res.ToolCalls = append(res.ToolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: *toolName, + Arguments: args, + }, + }) + } + } + res.DoneReason = cr.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) @@ -341,6 +391,15 @@ func (s *Server) GenerateHandler(c *gin.Context) { } } + if useHarmony { + // only send messages with meaningful content (empty messages confuse clients) + if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 { + ch <- res + } + + return + } + ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} @@ -1471,7 +1530,7 @@ func (s *Server) ChatHandler(c *gin.Context) { if len(req.Tools) > 0 { caps = append(caps, model.CapabilityTools) } - if req.Think != nil && *req.Think { + if req.Think != nil && req.Think.AsBool() { caps = append(caps, 
model.CapabilityThinking) } @@ -1521,9 +1580,30 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + useHarmony := shouldUseHarmony(*m) + + // Validate Think value: string values currently only allowed for gptoss models + if req.Think != nil && req.Think.IsString() && !useHarmony { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.AsString())}) + return + } + + var harmonyMessageHandler *HarmonyMessageHandler + var harmonyToolParser *HarmonyToolCallAccumulator + + if useHarmony { + harmonyMessageHandler = NewHarmonyMessageHandler() + var lastMessage *api.Message + if len(msgs) > 0 { + lastMessage = &msgs[len(msgs)-1] + } + harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage) + harmonyToolParser = harmonyMessageHandler.CreateToolParser() + } + var thinkingState *thinking.Parser openingTag, closingTag := thinking.InferTags(m.Template.Template) - if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" { + if req.Think != nil && req.Think.AsBool() && openingTag != "" && closingTag != "" { thinkingState = &thinking.Parser{ OpeningTag: openingTag, ClosingTag: closingTag, @@ -1531,7 +1611,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } var toolParser *tools.Parser - if len(req.Tools) > 0 { + if len(req.Tools) > 0 && !useHarmony { toolParser = tools.NewParser(m.Template.Template, req.Tools) } @@ -1557,6 +1637,38 @@ func (s *Server) ChatHandler(c *gin.Context) { EvalDuration: r.EvalDuration, }, } + if r.Done { + res.DoneReason = r.DoneReason.String() + res.TotalDuration = time.Since(checkpointStart) + res.LoadDuration = checkpointLoaded.Sub(checkpointStart) + } + + if useHarmony { + content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser) + res.Message.Content = content + res.Message.Thinking = thinking + harmonyToolParser.Add(toolContent) + + if r.Done { + toolName, toolContent := harmonyToolParser.Drain() + if toolName != nil { + *toolName = strings.TrimPrefix(*toolName, "functions.") + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(toolContent), &args); err != nil { + ch <- gin.H{"error parsing tool call": err.Error()} + return + } + res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}} + } + } + + // only send messages with meaningful content (empty messages confuse clients) + if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done { + ch <- res + } + + return + } if thinkingState != nil { thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content) @@ -1568,12 +1680,6 @@ func (s *Server) ChatHandler(c *gin.Context) { res.Message.Thinking = thinkingContent } - if r.Done { - res.DoneReason = r.DoneReason.String() - res.TotalDuration = time.Since(checkpointStart) - res.LoadDuration = checkpointLoaded.Sub(checkpointStart) - } - if len(req.Tools) > 0 { toolCalls, content := toolParser.Add(res.Message.Content) if len(content) > 0 { diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 75a246fc6..477d6b814 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -150,7 +150,7 @@ func TestGenerateChat(t *testing.T) { Messages: []api.Message{ {Role: "user", Content: "Hello!"}, }, - Think: &think, + Think: &api.ThinkValue{Value: think}, }) if w.Code != http.StatusBadRequest { diff --git a/server/routes_harmony_streaming_test.go 
b/server/routes_harmony_streaming_test.go new file mode 100644 index 000000000..503cb4d74 --- /dev/null +++ b/server/routes_harmony_streaming_test.go @@ -0,0 +1,712 @@ +package server + +// this test file is to test integration of harmony parser into routes.go (as +// opposed to harmonyparser_test.go, which tests the parser in isolation) + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/discover" + "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/llm" +) + +func getTestTools() []api.Tool { + return []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required"` + Properties map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + } `json:"properties"` + }{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + }{ + "location": { + Type: api.PropertyType{"string"}, + Description: "The city and state, e.g. San Francisco, CA", + }, + }, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "calculate", + Description: "Calculate a mathematical expression", + Parameters: struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required"` + Properties map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + } `json:"properties"` + }{ + Type: "object", + Required: []string{"expression"}, + Properties: map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + }{ + "expression": { + Type: api.PropertyType{"string"}, + Description: "The mathematical expression to calculate", + }, + }, + }, + }, + }, + } +} + +func createHarmonyTestModel(t *testing.T) (string, string) { + t.Helper() + + return createBinFile(t, ggml.KV{ + "general.architecture": "gptoss", + "llama.block_count": uint32(1), + "llama.context_length": uint32(8192), + "llama.embedding_length": uint32(4096), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(8), + "tokenizer.ggml.tokens": []string{""}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []*ggml.Tensor{ + {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 
4))}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + }) +} + +// TestChatHarmonyParserStreamingRealtime verifies that chunks are emitted as soon as they're available +func TestChatHarmonyParserStreamingRealtime(t *testing.T) { + gin.SetMode(gin.TestMode) + + type step struct { + input llm.CompletionResponse + wantContent string + wantThinking string + wantToolCalls []api.ToolCall + } + + testCases := []struct { + name string + steps []step + only bool + }{ + { + name: "content streams as it arrives", + steps: []step{ + { + input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false}, + wantContent: "Hello", + }, + { + input: llm.CompletionResponse{Content: ", world", Done: false}, + wantContent: ", world", + }, + { + input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + wantContent: "!", + }, + }, + }, + { + name: "thinking streams separately from content", + steps: []step{ + { + input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false}, + wantThinking: "Thinking...", + }, + { + input: llm.CompletionResponse{Content: "<|end|>", Done: false}, + // No output expected - just closes the analysis message and resets state to normal + }, + { + input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false}, + wantContent: "Answer", // After message end, state is reset to normal + }, + { + input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + // No output expected - just closes the assistant message + }, + }, + }, + { + name: "partial tags buffer until complete", + steps: []step{ + { + input: llm.CompletionResponse{Content: "<|chan", Done: false}, + // No output - partial tag + }, + { + input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false}, + // No output - still building tags + }, + { + input: llm.CompletionResponse{Content: "age|>Deep ", Done: false}, + wantThinking: "Deep ", + }, + { + input: llm.CompletionResponse{Content: "thought<|end|>", Done: false}, + wantThinking: "thought", + }, + { + input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + wantContent: "Done", // After message end, state is reset to normal + }, + }, + }, + { + name: "simple assistant after analysis", + steps: []step{ + { + input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + wantContent: "Answer", + wantThinking: "Think", + }, + }, + }, + { + name: "tool call parsed and returned correctly", + steps: []step{ + { + input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + wantContent: "The weather is sunny", + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: 
api.ToolCallFunctionArguments{ + "location": "San Francisco", + }, + }, + }, + }, + }, + }, + }, + { + name: "tool call with streaming JSON across chunks", + steps: []step{ + { + input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false}, + // No output yet - incomplete JSON + }, + { + input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false}, + // Still no output - incomplete JSON + }, + { + input: llm.CompletionResponse{Content: "2\"}", Done: true}, + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "calculate", + Arguments: api.ToolCallFunctionArguments{ + "expression": "2+2", + }, + }, + }, + }, + }, + }, + }, + } + + anyOnlies := false + for _, tc := range testCases { + if tc.only { + anyOnlies = true + } + } + + for _, tc := range testCases { + if anyOnlies && !tc.only { + continue + } + + t.Run(tc.name, func(t *testing.T) { + var chunks []api.ChatResponse + chunkIdx := 0 + + mockResponses := make([]llm.CompletionResponse, len(tc.steps)) + for i, step := range tc.steps { + mockResponses[i] = step.input + } + + mock := mockRunner{ + CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error { + for _, resp := range mockResponses { + fn(resp) + // Give the handler time to process each response + time.Sleep(30 * time.Millisecond) + } + return nil + }, + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: discover.GetGPUInfo, + getCpuFn: discover.GetCPUInfo, + reschedDelay: 100 * time.Millisecond, + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) { + req.successCh <- &runnerRef{ + llama: &mock, + } + }, + }, + } + + go s.sched.Run(t.Context()) + + // Create a simple test model + _, digest := createHarmonyTestModel(t) + + streamFalse := false + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "harmony-test-streaming", + Files: map[string]string{"test.gguf": digest}, + Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`, + Stream: &streamFalse, + }) + + if w.Code != 200 { + t.Fatalf("failed to create model: %d", w.Code) + } + + // Test chat endpoint with streaming + streamTrue := true + w = createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "harmony-test-streaming", + Messages: []api.Message{{Role: "user", Content: "Hello"}}, + Stream: &streamTrue, + Tools: getTestTools(), + }) + + if w.Code != 200 { + t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String()) + } + + // Parse all chunks + decoder := json.NewDecoder(w.Body) + for decoder.More() { + var chunk api.ChatResponse + if err := decoder.Decode(&chunk); err != nil { + t.Fatalf("failed to decode chunk: %v", err) + } + if chunk.Message.Content != "" || chunk.Message.Thinking != "" || len(chunk.Message.ToolCalls) > 0 { + chunks = append(chunks, chunk) + } + } + + // Log received chunks for debugging + if t.Failed() || len(chunks) == 0 { + t.Logf("Received %d chunks:", len(chunks)) + for i, chunk := range chunks { + t.Logf(" Chunk %d: content=%q thinking=%q", i, chunk.Message.Content, chunk.Message.Thinking) + } + } + + // Verify chunks match expected steps + for i, step := range tc.steps { + // Skip steps that don't expect any output + if step.wantContent == "" && step.wantThinking == "" && 
len(step.wantToolCalls) == 0 { + continue + } + + if chunkIdx >= len(chunks) { + t.Errorf("step %d: expected chunk not received (wanted content=%q thinking=%q)", + i, step.wantContent, step.wantThinking) + continue + } + + chunk := chunks[chunkIdx] + if chunk.Message.Content != step.wantContent || chunk.Message.Thinking != step.wantThinking { + t.Errorf("step %d: chunk mismatch: got (content=%q, thinking=%q), want (content=%q, thinking=%q)", + i, chunk.Message.Content, chunk.Message.Thinking, step.wantContent, step.wantThinking) + } + + // Check tool calls if expected + if len(step.wantToolCalls) > 0 { + if len(chunk.Message.ToolCalls) != len(step.wantToolCalls) { + t.Errorf("step %d: tool calls count mismatch: got %d, want %d", + i, len(chunk.Message.ToolCalls), len(step.wantToolCalls)) + } else { + for j, wantCall := range step.wantToolCalls { + if j >= len(chunk.Message.ToolCalls) { + break + } + gotCall := chunk.Message.ToolCalls[j] + if gotCall.Function.Name != wantCall.Function.Name { + t.Errorf("step %d, tool call %d: name mismatch: got %q, want %q", + i, j, gotCall.Function.Name, wantCall.Function.Name) + } + // Compare arguments as JSON strings for simplicity + gotArgs, _ := json.Marshal(gotCall.Function.Arguments) + wantArgs, _ := json.Marshal(wantCall.Function.Arguments) + if string(gotArgs) != string(wantArgs) { + t.Errorf("step %d, tool call %d: arguments mismatch: got %s, want %s", + i, j, string(gotArgs), string(wantArgs)) + } + } + } + } + chunkIdx++ + } + + // Check if we have extra chunks + if chunkIdx < len(chunks) { + t.Errorf("received %d extra chunks", len(chunks)-chunkIdx) + for i := chunkIdx; i < len(chunks); i++ { + t.Logf(" extra chunk %d: content=%q thinking=%q", + i-chunkIdx, chunks[i].Message.Content, chunks[i].Message.Thinking) + } + } + }) + } +} + +// TestChatHarmonyParserStreamingSimple is a simpler test that just verifies basic streaming +func TestChatHarmonyParserStreamingSimple(t *testing.T) { + gin.SetMode(gin.TestMode) + + mockResponses := []llm.CompletionResponse{ + {Content: "<|message|>First ", Done: false}, + {Content: "chunk ", Done: false}, + {Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + } + + mock := mockRunner{ + CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error { + t.Logf("Mock received prompt: %q", r.Prompt) + t.Logf("Mock sending %d responses", len(mockResponses)) + for i, resp := range mockResponses { + t.Logf("Sending response %d: %q", i, resp.Content) + fn(resp) + } + return nil + }, + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: discover.GetGPUInfo, + getCpuFn: discover.GetCPUInfo, + reschedDelay: 100 * time.Millisecond, + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) { + req.successCh <- &runnerRef{ + llama: &mock, + } + }, + }, + } + + go s.sched.Run(t.Context()) + + // Create model + _, digest := createHarmonyTestModel(t) + streamFalse := false + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "gpt-oss", + Files: map[string]string{"test.gguf": digest}, + Template: `<|start|><|end|>{{ .Tools }}{{ .Prompt }}`, + Stream: &streamFalse, + }) + + if w.Code != 200 { + t.Fatalf("failed to create model: %d", w.Code) + } + + // Test streaming + streamTrue := true + w = 
createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "gpt-oss", + Messages: []api.Message{{Role: "user", Content: "Hello"}}, + Stream: &streamTrue, + Tools: getTestTools(), + }) + + if w.Code != 200 { + t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String()) + } + + // Parse chunks + var chunks []api.ChatResponse + decoder := json.NewDecoder(w.Body) + for decoder.More() { + var chunk api.ChatResponse + if err := decoder.Decode(&chunk); err != nil { + t.Fatalf("failed to decode chunk: %v", err) + } + chunks = append(chunks, chunk) + t.Logf("Received chunk %d: content=%q thinking=%q done=%v", + len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done) + } + + // Verify we got chunks + if len(chunks) == 0 { + t.Fatal("expected streaming chunks, got none") + } + + // Verify content + var content strings.Builder + for _, chunk := range chunks { + content.WriteString(chunk.Message.Content) + } + + expectedContent := "First chunk here" + if content.String() != expectedContent { + t.Errorf("content mismatch: got %q, want %q", content.String(), expectedContent) + } + + // Verify we got multiple chunks (streaming) + contentChunks := 0 + for _, chunk := range chunks { + if chunk.Message.Content != "" { + contentChunks++ + } + } + + if contentChunks < 2 { + t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks) + } +} + +func TestChatHarmonyParserStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + type expectedChunk struct { + afterResponse int // Which mock response this chunk should appear after + content string // Expected content in this chunk + thinking string // Expected thinking in this chunk + } + + testCases := []struct { + name string + mockResponses []llm.CompletionResponse + expectedChunks []expectedChunk + wantContent string + wantThinking string + }{ + { + name: "simple message without thinking", + mockResponses: []llm.CompletionResponse{ + {Content: "<|start|>assistant<|message|>Hello, ", Done: false}, + {Content: "how can I help?", Done: false}, + {Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + }, + expectedChunks: []expectedChunk{ + {afterResponse: 1, content: "Hello, "}, + {afterResponse: 2, content: "how can I help?"}, + }, + wantContent: "Hello, how can I help?", + }, + { + name: "message with analysis channel for thinking", + mockResponses: []llm.CompletionResponse{ + {Content: "<|channel|>analysis<|message|>", Done: false}, + {Content: "Let me think ", Done: false}, + {Content: "about this problem...", Done: false}, + {Content: "<|end|>", Done: false}, + {Content: "<|start|>assistant<|message|>", Done: false}, + {Content: "The answer ", Done: false}, + {Content: "is 42", Done: false}, + {Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + }, + expectedChunks: []expectedChunk{ + {afterResponse: 2, thinking: "Let me think "}, + {afterResponse: 3, thinking: "about this problem..."}, + {afterResponse: 6, content: "The answer "}, + {afterResponse: 7, content: "is 42"}, + }, + wantContent: "The answer is 42", + wantThinking: "Let me think about this problem...", + }, + { + name: "streaming with partial tags across boundaries", + mockResponses: []llm.CompletionResponse{ + {Content: "<|chan", Done: false}, + {Content: "nel|>analy", Done: false}, + {Content: "sis<|mess", Done: false}, + {Content: "age|>Think", Done: false}, + {Content: "ing deeply...<|end|>", Done: false}, + {Content: "<|start|>assi", Done: false}, + {Content: "stant<|message|>Result ", Done: false}, + {Content: 
"computed<|e", Done: false}, + {Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop}, + }, + expectedChunks: []expectedChunk{ + {afterResponse: 4, thinking: "Think"}, + {afterResponse: 5, thinking: "ing deeply..."}, + {afterResponse: 7, content: "Result "}, + {afterResponse: 8, content: "computed"}, + }, + wantContent: "Result computed", + wantThinking: "Thinking deeply...", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Channel to synchronize mock responses with chunk verification + responsesSent := make(chan int, len(tc.mockResponses)) + + mock := mockRunner{ + CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error { + // Send mock responses one at a time, notifying when each is sent + for i, resp := range tc.mockResponses { + fn(resp) + responsesSent <- i + 1 + } + close(responsesSent) + return nil + }, + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: discover.GetGPUInfo, + getCpuFn: discover.GetCPUInfo, + reschedDelay: 250 * time.Millisecond, + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) { + req.successCh <- &runnerRef{ + llama: &mock, + } + }, + }, + } + + go s.sched.Run(t.Context()) + + // Create a minimal model + _, digest := createHarmonyTestModel(t) + + // Create model with passthrough template + stream := false + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "harmony-test", + Files: map[string]string{"file.gguf": digest}, + Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("failed to create model: %d", w.Code) + } + + // Test chat endpoint with streaming + streamTrue := true + w = createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "harmony-test", + Messages: []api.Message{{Role: "user", Content: "Hello"}}, + Stream: &streamTrue, + Tools: getTestTools(), + }) + + if w.Code != http.StatusOK { + t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String()) + } + + // Parse streaming response + var chunks []api.ChatResponse + var content, thinking strings.Builder + + decoder := json.NewDecoder(w.Body) + for decoder.More() { + var chunk api.ChatResponse + if err := decoder.Decode(&chunk); err != nil { + t.Fatalf("failed to decode chunk: %v", err) + } + chunks = append(chunks, chunk) + + // Accumulate content and thinking from each chunk + content.WriteString(chunk.Message.Content) + thinking.WriteString(chunk.Message.Thinking) + + // Debug output + t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done) + } + + // Verify we got streaming chunks + if len(chunks) == 0 { + t.Fatal("expected streaming chunks, got none") + } + + gotContent := content.String() + gotThinking := thinking.String() + + if gotContent != tc.wantContent { + t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent) + } + if gotThinking != tc.wantThinking { + t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking) + } + + // Verify last chunk has done=true + lastChunk := chunks[len(chunks)-1] + if !lastChunk.Done { + t.Error("expected last chunk to have done=true") + } + }) + } +} diff --git a/template/template.go b/template/template.go 
index d28ace413..bfd02a92d 100644 --- a/template/template.go +++ b/template/template.go @@ -13,6 +13,7 @@ import ( "sync" "text/template" "text/template/parse" + "time" "github.com/agnivade/levenshtein" @@ -121,6 +122,11 @@ var funcs = template.FuncMap{ b, _ := json.Marshal(v) return string(b) }, + "currentDate": func(args ...string) string { + // Currently ignoring the format argument, but accepting it for future use + // Default format is YYYY-MM-DD + return time.Now().Format("2006-01-02") + }, } func Parse(s string) (*Template, error) { @@ -160,12 +166,18 @@ func (t *Template) Vars() []string { return slices.Sorted(maps.Keys(set)) } +func (t *Template) Contains(s string) bool { + return strings.Contains(t.raw, s) +} + type Values struct { Messages []api.Message api.Tools Prompt string Suffix string Think bool + // ThinkLevel contains the thinking level if Think is true and a string value was provided + ThinkLevel string // whether or not the user explicitly set the thinking flag (vs. it being // implicitly false). Templates can't see whether `Think` is nil IsThinkSet bool @@ -228,6 +240,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { "Suffix": v.Suffix, "Response": "", "Think": v.Think, + "ThinkLevel": v.ThinkLevel, "IsThinkSet": v.IsThinkSet, }) } else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { @@ -237,6 +250,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { "Tools": v.Tools, "Response": "", "Think": v.Think, + "ThinkLevel": v.ThinkLevel, "IsThinkSet": v.IsThinkSet, }) } @@ -251,6 +265,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { "Prompt": prompt, "Response": response, "Think": v.Think, + "ThinkLevel": v.ThinkLevel, "IsThinkSet": v.IsThinkSet, }); err != nil { return err @@ -298,6 +313,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { "Prompt": prompt, "Response": response, "Think": v.Think, + "ThinkLevel": v.ThinkLevel, "IsThinkSet": v.IsThinkSet, }); err != nil { return err diff --git a/tools/tools.go b/tools/tools.go index f473ab6a6..f9ca15530 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -26,6 +26,10 @@ type Parser struct { n int } +func (p *Parser) GetBuffer() []byte { + return p.buffer +} + // NewParser creates a new tool call parser from a model's chat // template and a list of provided tools. func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {