mirror of https://github.com/ollama/ollama.git
add pre:, suf: to tags (#12274)
This commit is contained in:
parent
a40d427bce
commit
bf78ed6ee9
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
_ "image/png"
|
_ "image/png"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -171,35 +172,42 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||||
// make a copy
|
// make a copy
|
||||||
tagsCopy := tags
|
tagsCopy := tags
|
||||||
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
|
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
|
||||||
tagsCopy = append(tagsCopy, ParseTags(tag))
|
tagsCopy = append(tagsCopy, parseTag(tag))
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
||||||
vv.Set(reflect.ValueOf(base))
|
vv.Set(reflect.ValueOf(base))
|
||||||
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
||||||
var fn func([]Tag) [][]string
|
var fn func([]Tag, string, string) [][]string
|
||||||
fn = func(tags []Tag) (names [][]string) {
|
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
|
||||||
if len(tags) > 0 {
|
if len(tags) > 0 {
|
||||||
localNames := []string{tags[0].Name}
|
var names []string
|
||||||
localNames = append(localNames, tags[0].Alternate...)
|
if tags[0].name != "" {
|
||||||
|
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
|
||||||
for _, localName := range localNames {
|
names = append(names, prefix+n+suffix)
|
||||||
fullName := []string{localName}
|
|
||||||
nested := fn(tags[1:])
|
|
||||||
if len(nested) > 0 {
|
|
||||||
for _, rest := range nested {
|
|
||||||
names = append(names, append(fullName, rest...))
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix); len(childNames) == 0 {
|
||||||
|
// no child names, append current names
|
||||||
|
fullNames = append(fullNames, names)
|
||||||
|
} else if len(names) == 0 {
|
||||||
|
// no current names, append child names
|
||||||
|
fullNames = append(fullNames, childNames...)
|
||||||
} else {
|
} else {
|
||||||
names = append(names, fullName)
|
// combine current and child names
|
||||||
|
for _, name := range names {
|
||||||
|
for _, childName := range childNames {
|
||||||
|
fullNames = append(fullNames, append([]string{name}, childName...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return names
|
return fullNames
|
||||||
}
|
}
|
||||||
|
|
||||||
names := fn(tagsCopy)
|
names := fn(tagsCopy, "", "")
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||||
logutil.Trace("found tensor", "", tensor)
|
logutil.Trace("found tensor", "", tensor)
|
||||||
|
@ -213,9 +221,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||||
for i := range vv.Len() {
|
for i := range vv.Len() {
|
||||||
vvv := vv.Index(i)
|
vvv := vv.Index(i)
|
||||||
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
||||||
setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
|
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
|
||||||
} else {
|
} else {
|
||||||
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
|
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -254,18 +262,31 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tag struct {
|
type Tag struct {
|
||||||
Name string
|
name,
|
||||||
Alternate []string
|
// prefix and suffix are applied to child tags
|
||||||
|
prefix,
|
||||||
|
suffix string
|
||||||
|
alternatives []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseTags(s string) (tag Tag) {
|
func parseTag(s string) (tag Tag) {
|
||||||
parts := strings.Split(s, ",")
|
parts := strings.Split(s, ",")
|
||||||
if len(parts) > 0 {
|
if len(parts) > 0 {
|
||||||
tag.Name = parts[0]
|
tag.name = parts[0]
|
||||||
|
|
||||||
for _, part := range parts[1:] {
|
for _, part := range parts[1:] {
|
||||||
if value, ok := strings.CutPrefix(part, "alt:"); ok {
|
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
|
||||||
tag.Alternate = append(tag.Alternate, value)
|
// elevate alternative to primary if no primary given
|
||||||
|
tag.name = value
|
||||||
|
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
|
||||||
|
} else if ok {
|
||||||
|
tag.alternatives = append(tag.alternatives, value)
|
||||||
|
}
|
||||||
|
if value, ok := strings.CutPrefix(part, "pre:"); ok {
|
||||||
|
tag.prefix = value
|
||||||
|
}
|
||||||
|
if value, ok := strings.CutPrefix(part, "suf:"); ok {
|
||||||
|
tag.suffix = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,14 +22,14 @@ func TestParseTags(t *testing.T) {
|
||||||
{
|
{
|
||||||
value: "output",
|
value: "output",
|
||||||
want: Tag{
|
want: Tag{
|
||||||
Name: "output",
|
name: "output",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
value: "output,alt:token_embd",
|
value: "output,alt:token_embd",
|
||||||
want: Tag{
|
want: Tag{
|
||||||
Name: "output",
|
name: "output",
|
||||||
Alternate: []string{
|
alternatives: []string{
|
||||||
"token_embd",
|
"token_embd",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -38,8 +38,8 @@ func TestParseTags(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(tt.value, func(t *testing.T) {
|
t.Run(tt.value, func(t *testing.T) {
|
||||||
got := ParseTags(tt.value)
|
got := parseTag(tt.value)
|
||||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" {
|
||||||
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
|
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -147,6 +147,57 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPopulateFieldsPrefixSuffixName(t *testing.T) {
|
||||||
|
type fakeBlock struct {
|
||||||
|
A *nn.Linear `gguf:"a"`
|
||||||
|
B *nn.Linear `gguf:",pre:b_"`
|
||||||
|
C *nn.Linear `gguf:",suf:_c"`
|
||||||
|
XY *nn.Linear `gguf:",pre:x_,suf:_y"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeModel struct {
|
||||||
|
Blocks []fakeBlock `gguf:"blk"`
|
||||||
|
}
|
||||||
|
|
||||||
|
m := fakeModel{
|
||||||
|
Blocks: make([]fakeBlock, 2),
|
||||||
|
}
|
||||||
|
v := reflect.ValueOf(&m)
|
||||||
|
v.Elem().Set(populateFields(Base{b: &fakeBackend{
|
||||||
|
names: []string{
|
||||||
|
"blk.0.a.weight",
|
||||||
|
"blk.0.b_weight",
|
||||||
|
"blk.0.b_bias",
|
||||||
|
"blk.0.weight_c",
|
||||||
|
"blk.0.x_weight_y",
|
||||||
|
"blk.1.a.weight",
|
||||||
|
"blk.1.b_weight",
|
||||||
|
"blk.1.b_bias",
|
||||||
|
"blk.1.weight_c",
|
||||||
|
"blk.1.x_weight_y",
|
||||||
|
},
|
||||||
|
}}, v.Elem()))
|
||||||
|
|
||||||
|
if diff := cmp.Diff(fakeModel{
|
||||||
|
Blocks: []fakeBlock{
|
||||||
|
{
|
||||||
|
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}},
|
||||||
|
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}},
|
||||||
|
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}},
|
||||||
|
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}},
|
||||||
|
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}},
|
||||||
|
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}},
|
||||||
|
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, m); diff != "" {
|
||||||
|
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestModelForArch(t *testing.T) {
|
func TestModelForArch(t *testing.T) {
|
||||||
type fakeModel struct {
|
type fakeModel struct {
|
||||||
Model
|
Model
|
||||||
|
|
|
@ -88,22 +88,10 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
|
||||||
return nextStates
|
return nextStates
|
||||||
}
|
}
|
||||||
|
|
||||||
// TextSharedExpert is TextMLP with different tensor names
|
|
||||||
type TextSharedExpert struct {
|
|
||||||
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
|
|
||||||
Up *nn.Linear `gguf:"ffn_up_shexp"`
|
|
||||||
Down *nn.Linear `gguf:"ffn_down_shexp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
||||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
|
||||||
return mlp.Down.Forward(ctx, hiddenStates)
|
|
||||||
}
|
|
||||||
|
|
||||||
type TextMOE struct {
|
type TextMOE struct {
|
||||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||||
Experts *TextExperts
|
Experts *TextExperts
|
||||||
SharedExpert *TextSharedExpert
|
SharedExpert *TextMLP `gguf:",suf:_shexp"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||||
|
|
Loading…
Reference in New Issue