2025-05-01 08:59:31 +08:00
|
|
|
package ggml
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2025-06-17 01:42:32 +08:00
|
|
|
"math/rand/v2"
|
2025-05-01 08:59:31 +08:00
|
|
|
"os"
|
2025-06-17 01:42:32 +08:00
|
|
|
"strings"
|
2025-05-01 08:59:31 +08:00
|
|
|
"testing"
|
|
|
|
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestWriteGGUF(t *testing.T) {
|
2025-08-27 04:57:46 +08:00
|
|
|
b := bytes.NewBuffer(make([]byte, 2*3))
|
2025-06-17 01:42:32 +08:00
|
|
|
for range 8 {
|
|
|
|
t.Run("shuffle", func(t *testing.T) {
|
|
|
|
t.Parallel()
|
2025-05-01 08:59:31 +08:00
|
|
|
|
2025-06-17 01:42:32 +08:00
|
|
|
ts := []*Tensor{
|
2025-08-27 04:57:46 +08:00
|
|
|
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
|
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
|
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
|
|
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
|
|
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
|
|
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
|
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
|
|
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
|
|
|
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
2025-06-17 01:42:32 +08:00
|
|
|
}
|
2025-05-01 08:59:31 +08:00
|
|
|
|
2025-08-27 04:57:46 +08:00
|
|
|
rand.Shuffle(len(ts), func(i, j int) {
|
2025-06-17 01:42:32 +08:00
|
|
|
ts[i], ts[j] = ts[j], ts[i]
|
|
|
|
})
|
2025-05-01 08:59:31 +08:00
|
|
|
|
2025-06-17 01:42:32 +08:00
|
|
|
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
defer w.Close()
|
|
|
|
|
|
|
|
if err := WriteGGUF(w, KV{
|
|
|
|
"general.alignment": uint32(16),
|
|
|
|
}, ts); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
r, err := os.Open(w.Name())
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
defer r.Close()
|
|
|
|
|
|
|
|
ff, err := Decode(r, 0)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if diff := cmp.Diff(KV{
|
|
|
|
"general.alignment": uint32(16),
|
|
|
|
"general.parameter_count": uint64(54),
|
|
|
|
}, ff.KV()); diff != "" {
|
|
|
|
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
|
|
|
}
|
2025-05-01 08:59:31 +08:00
|
|
|
|
2025-06-17 01:42:32 +08:00
|
|
|
if diff := cmp.Diff(Tensors{
|
2025-08-27 04:57:46 +08:00
|
|
|
Offset: 592,
|
2025-06-17 01:42:32 +08:00
|
|
|
items: []*Tensor{
|
2025-08-27 04:57:46 +08:00
|
|
|
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
|
|
|
|
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
|
|
|
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
|
|
|
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
|
|
|
|
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
|
|
|
|
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
2025-06-17 01:42:32 +08:00
|
|
|
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
|
|
|
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
|
|
|
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
|
|
|
},
|
|
|
|
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
|
|
|
|
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
|
|
|
}
|
|
|
|
})
|
2025-05-01 08:59:31 +08:00
|
|
|
}
|
|
|
|
}
|