add pre:, suf: to tags (#12274)

This commit is contained in:
Michael Yang 2025-09-23 16:08:57 -07:00 committed by GitHub
parent a40d427bce
commit bf78ed6ee9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 101 additions and 41 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
_ "image/jpeg" _ "image/jpeg"
_ "image/png" _ "image/png"
"log/slog"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
@ -171,35 +172,42 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
// make a copy // make a copy
tagsCopy := tags tagsCopy := tags
if tag := t.Field(i).Tag.Get("gguf"); tag != "" { if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
tagsCopy = append(tagsCopy, ParseTags(tag)) tagsCopy = append(tagsCopy, parseTag(tag))
} }
if tt == reflect.TypeOf((*Base)(nil)).Elem() { if tt == reflect.TypeOf((*Base)(nil)).Elem() {
vv.Set(reflect.ValueOf(base)) vv.Set(reflect.ValueOf(base))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
var fn func([]Tag) [][]string var fn func([]Tag, string, string) [][]string
fn = func(tags []Tag) (names [][]string) { fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
if len(tags) > 0 { if len(tags) > 0 {
localNames := []string{tags[0].Name} var names []string
localNames = append(localNames, tags[0].Alternate...) if tags[0].name != "" {
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
names = append(names, prefix+n+suffix)
}
}
for _, localName := range localNames { if childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix); len(childNames) == 0 {
fullName := []string{localName} // no child names, append current names
nested := fn(tags[1:]) fullNames = append(fullNames, names)
if len(nested) > 0 { } else if len(names) == 0 {
for _, rest := range nested { // no current names, append child names
names = append(names, append(fullName, rest...)) fullNames = append(fullNames, childNames...)
} else {
// combine current and child names
for _, name := range names {
for _, childName := range childNames {
fullNames = append(fullNames, append([]string{name}, childName...))
} }
} else {
names = append(names, fullName)
} }
} }
} }
return names return fullNames
} }
names := fn(tagsCopy) names := fn(tagsCopy, "", "")
for _, name := range names { for _, name := range names {
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
logutil.Trace("found tensor", "", tensor) logutil.Trace("found tensor", "", tensor)
@ -213,9 +221,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
for i := range vv.Len() { for i := range vv.Len() {
vvv := vv.Index(i) vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
} else { } else {
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
} }
} }
} }
@ -254,18 +262,31 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
} }
type Tag struct { type Tag struct {
Name string name,
Alternate []string // prefix and suffix are applied to child tags
prefix,
suffix string
alternatives []string
} }
func ParseTags(s string) (tag Tag) { func parseTag(s string) (tag Tag) {
parts := strings.Split(s, ",") parts := strings.Split(s, ",")
if len(parts) > 0 { if len(parts) > 0 {
tag.Name = parts[0] tag.name = parts[0]
for _, part := range parts[1:] { for _, part := range parts[1:] {
if value, ok := strings.CutPrefix(part, "alt:"); ok { if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
tag.Alternate = append(tag.Alternate, value) // elevate alternative to primary if no primary given
tag.name = value
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
} else if ok {
tag.alternatives = append(tag.alternatives, value)
}
if value, ok := strings.CutPrefix(part, "pre:"); ok {
tag.prefix = value
}
if value, ok := strings.CutPrefix(part, "suf:"); ok {
tag.suffix = value
} }
} }
} }

View File

@ -22,14 +22,14 @@ func TestParseTags(t *testing.T) {
{ {
value: "output", value: "output",
want: Tag{ want: Tag{
Name: "output", name: "output",
}, },
}, },
{ {
value: "output,alt:token_embd", value: "output,alt:token_embd",
want: Tag{ want: Tag{
Name: "output", name: "output",
Alternate: []string{ alternatives: []string{
"token_embd", "token_embd",
}, },
}, },
@ -38,8 +38,8 @@ func TestParseTags(t *testing.T) {
for _, tt := range cases { for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) { t.Run(tt.value, func(t *testing.T) {
got := ParseTags(tt.value) got := parseTag(tt.value)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" {
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff) t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
} }
}) })
@ -147,6 +147,57 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
} }
} }
// TestPopulateFieldsPrefixSuffixName exercises populateFields on fields whose
// gguf tag has an empty name and only a pre: and/or suf: modifier, alongside a
// plainly named field for comparison.
func TestPopulateFieldsPrefixSuffixName(t *testing.T) {
	type fakeBlock struct {
		A  *nn.Linear `gguf:"a"`
		B  *nn.Linear `gguf:",pre:b_"`
		C  *nn.Linear `gguf:",suf:_c"`
		XY *nn.Linear `gguf:",pre:x_,suf:_y"`
	}

	type fakeModel struct {
		Blocks []fakeBlock `gguf:"blk"`
	}

	// Tensor names the fake backend knows about. Only the "b_" entries
	// include a bias, so only B is expected to end up with Bias set.
	backend := &fakeBackend{
		names: []string{
			"blk.0.a.weight",
			"blk.0.b_weight",
			"blk.0.b_bias",
			"blk.0.weight_c",
			"blk.0.x_weight_y",
			"blk.1.a.weight",
			"blk.1.b_weight",
			"blk.1.b_bias",
			"blk.1.weight_c",
			"blk.1.x_weight_y",
		},
	}

	got := fakeModel{
		Blocks: make([]fakeBlock, 2),
	}

	// Populate through a pointer so the resulting value can be stored back.
	target := reflect.ValueOf(&got).Elem()
	target.Set(populateFields(Base{b: backend}, target))

	want := fakeModel{
		Blocks: []fakeBlock{
			{
				A:  &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}},
				B:  &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}},
				C:  &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}},
				XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}},
			},
			{
				A:  &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}},
				B:  &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}},
				C:  &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}},
				XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}},
			},
		},
	}

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
	}
}
func TestModelForArch(t *testing.T) { func TestModelForArch(t *testing.T) {
type fakeModel struct { type fakeModel struct {
Model Model

View File

@ -88,22 +88,10 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
return nextStates return nextStates
} }
// TextSharedExpert is TextMLP with different tensor names
//
// Each field's gguf tag ends in "_shexp"; its Forward method combines the
// three projections as Down(SILU-gated(Gate, Up)).
type TextSharedExpert struct {
// Gate is the projection whose output has SILU applied in Forward.
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
// Up is the projection passed as the second operand to SILU in Forward.
Up *nn.Linear `gguf:"ffn_up_shexp"`
// Down is applied last in Forward, to the gated result.
Down *nn.Linear `gguf:"ffn_down_shexp"`
}
// Forward runs hiddenStates through the Gate and Up projections, combines them
// with SILU, and returns the Down projection of that result. opts is accepted
// but not used by this method.
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
	// Gate output is the SILU receiver; Up output is passed as its operand.
	gated := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))

	return mlp.Down.Forward(ctx, gated)
}
type TextMOE struct { type TextMOE struct {
Router *nn.Linear `gguf:"ffn_gate_inp"` Router *nn.Linear `gguf:"ffn_gate_inp"`
Experts *TextExperts Experts *TextExperts
SharedExpert *TextSharedExpert SharedExpert *TextMLP `gguf:",suf:_shexp"`
} }
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {