package ggml

import (
	"maps"
	"slices"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)
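
// TestTensorLayers verifies that Tensors.GroupLayers groups a flat tensor
// list into layers keyed by name prefix: transformer block tensors land
// under "blk.N" (or "v.blk.N" for the vision tower), multimodal projector
// tensors under "mm.N", and remaining top-level tensors under their first
// name component (e.g. "token_embd", "output_norm", "v").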
func TestTensorLayers(t *testing.T) {
	tensors := make(map[string]*Tensor)
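
	// Fixture: a mix of text-model tensors, multimodal projector ("mm.")
	// tensors, and vision tower ("v.") tensors, so each case below can
	// select a subset by prefix.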
	for _, name := range []string{
		"token_embd.weight",
		"blk.0.attn_k.weight",
		"blk.0.attn_output.weight",
		"blk.0.attn_q.weight",
		"blk.0.attn_v.weight",
		"blk.0.attn_norm.weight",
		"blk.0.ffn_down.weight",
		"blk.0.ffn_gate.weight",
		"blk.0.ffn_up.weight",
		"blk.0.ffn_norm.weight",
		"output_norm.weight",
		"mm.0.bias",
		"mm.0.weight",
		"v.blk.0.attn_k.weight",
		"v.blk.0.attn_output.weight",
		"v.blk.0.attn_q.weight",
		"v.blk.0.attn_v.weight",
		"v.blk.0.attn_norm.weight",
		"v.blk.0.ffn_down.weight",
		"v.blk.0.ffn_gate.weight",
		"v.blk.0.ffn_up.weight",
		"v.blk.0.ffn_norm.weight",
		"v.patch_embd.weight",
		"v.position_embd.gate",
		"v.position_embd.weight",
	} {
		tensors[name] = &Tensor{Name: name}
	}

	cases := []struct {
		name  string
		items []*Tensor
		want  map[string]Layer
	}{
		{
			name: "text",
			items: slices.Collect(func(yield func(*Tensor) bool) {
				for k, v := range tensors {
					if !strings.HasPrefix(k, "mm.") && !strings.HasPrefix(k, "v.") {
						if !yield(v) {
							return
						}
					}
				}
			}),
			want: map[string]Layer{
				"blk.0": {
					"attn_k.weight":      tensors["blk.0.attn_k.weight"],
					"attn_q.weight":      tensors["blk.0.attn_q.weight"],
					"attn_v.weight":      tensors["blk.0.attn_v.weight"],
					"attn_output.weight": tensors["blk.0.attn_output.weight"],
					"attn_norm.weight":   tensors["blk.0.attn_norm.weight"],
					"ffn_down.weight":    tensors["blk.0.ffn_down.weight"],
					"ffn_gate.weight":    tensors["blk.0.ffn_gate.weight"],
					"ffn_up.weight":      tensors["blk.0.ffn_up.weight"],
					"ffn_norm.weight":    tensors["blk.0.ffn_norm.weight"],
				},
				"token_embd":  {"weight": tensors["token_embd.weight"]},
				"output_norm": {"weight": tensors["output_norm.weight"]},
			},
		},
		{
			name: "vision",
			items: slices.Collect(func(yield func(*Tensor) bool) {
				for k, v := range tensors {
					if strings.HasPrefix(k, "mm.") || strings.HasPrefix(k, "v.") {
						if !yield(v) {
							return
						}
					}
				}
			}),
			want: map[string]Layer{
				"mm.0": {
					"bias":   tensors["mm.0.bias"],
					"weight": tensors["mm.0.weight"],
				},
				"v.blk.0": {
					"attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
					"attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
					"attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
					"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
					"attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
					"ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
					"ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
					"ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
					"ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
				},
				"v": {
					"patch_embd.weight":    tensors["v.patch_embd.weight"],
					"position_embd.gate":   tensors["v.position_embd.gate"],
					"position_embd.weight": tensors["v.position_embd.weight"],
				},
			},
		},
		{
			name: "vision and text",
			items: slices.Collect(maps.Values(tensors)),
			want: map[string]Layer{
				"blk.0": {
					"attn_k.weight":      tensors["blk.0.attn_k.weight"],
					"attn_q.weight":      tensors["blk.0.attn_q.weight"],
					"attn_v.weight":      tensors["blk.0.attn_v.weight"],
					"attn_output.weight": tensors["blk.0.attn_output.weight"],
					"attn_norm.weight":   tensors["blk.0.attn_norm.weight"],
					"ffn_down.weight":    tensors["blk.0.ffn_down.weight"],
					"ffn_gate.weight":    tensors["blk.0.ffn_gate.weight"],
					"ffn_up.weight":      tensors["blk.0.ffn_up.weight"],
					"ffn_norm.weight":    tensors["blk.0.ffn_norm.weight"],
				},
				"token_embd":  {"weight": tensors["token_embd.weight"]},
				"output_norm": {"weight": tensors["output_norm.weight"]},
				"mm.0": {
					"bias":   tensors["mm.0.bias"],
					"weight": tensors["mm.0.weight"],
				},
				"v.blk.0": {
					"attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
					"attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
					"attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
					"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
					"attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
					"ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
					"ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
					"ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
					"ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
				},
				"v": {
					"patch_embd.weight":    tensors["v.patch_embd.weight"],
					"position_embd.gate":   tensors["v.position_embd.gate"],
					"position_embd.weight": tensors["v.position_embd.weight"],
				},
			},
		},
	}
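
	// items is collected from map iteration, so input order is
	// nondeterministic; the grouping must come out the same regardless.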
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			got := Tensors{items: tt.items}.GroupLayers()
			if diff := cmp.Diff(got, tt.want); diff != "" {
				t.Errorf("unexpected layers (-got +want):\n%s", diff)
			}
		})
	}
}
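
// splitLayerName is a hypothetical sketch added here only to document the
// naming scheme the expectations above encode; it is not the actual
// GroupLayers implementation, which need not work this way. Text and
// projector block tensors keep two leading components ("blk.0", "mm.0"),
// vision-tower block tensors keep three ("v.blk.0"), and everything else
// keeps one ("token_embd", "output_norm", "v").
func splitLayerName(name string) (layer, key string) {
	parts := strings.Split(name, ".")
	switch {
	case len(parts) > 2 && (parts[0] == "blk" || parts[0] == "mm"):
		// e.g. "blk.0.attn_k.weight" -> ("blk.0", "attn_k.weight")
		return strings.Join(parts[:2], "."), strings.Join(parts[2:], ".")
	case len(parts) > 3 && parts[0] == "v" && parts[1] == "blk":
		// e.g. "v.blk.0.ffn_up.weight" -> ("v.blk.0", "ffn_up.weight")
		return strings.Join(parts[:3], "."), strings.Join(parts[3:], ".")
	default:
		// e.g. "v.position_embd.gate" -> ("v", "position_embd.gate")
		return parts[0], strings.Join(parts[1:], ".")
	}
}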