mirror of https://github.com/ollama/ollama.git
295 lines
8.9 KiB
Go
295 lines
8.9 KiB
Go
package convert
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/d4l3k/go-bfloat16"
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/x448/float16"
|
|
)
|
|
|
|
func TestSafetensors(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
root, err := os.OpenRoot(t.TempDir())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer root.Close()
|
|
|
|
cases := []struct {
|
|
name,
|
|
dtype string
|
|
offset,
|
|
size int64
|
|
shape []uint64
|
|
setup func(*testing.T, *os.File)
|
|
want []byte
|
|
}{
|
|
{
|
|
name: "fp32-fp32",
|
|
dtype: "F32",
|
|
size: 32 * 4, // 32 floats, each 4 bytes
|
|
shape: []uint64{32},
|
|
setup: func(t *testing.T, f *os.File) {
|
|
f32s := make([]float32, 32)
|
|
for i := range f32s {
|
|
f32s[i] = float32(i)
|
|
}
|
|
|
|
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
},
|
|
want: []byte{
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
|
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
|
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
|
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
|
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
|
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
|
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
|
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
|
},
|
|
},
|
|
{
|
|
name: "fp32-fp16",
|
|
dtype: "F32",
|
|
size: 32 * 4, // 32 floats, each 4 bytes
|
|
shape: []uint64{16, 2},
|
|
setup: func(t *testing.T, f *os.File) {
|
|
f32s := make([]float32, 32)
|
|
for i := range f32s {
|
|
f32s[i] = float32(i)
|
|
}
|
|
|
|
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
},
|
|
want: []byte{
|
|
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
|
0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
|
|
0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
|
|
0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
|
|
},
|
|
},
|
|
{
|
|
name: "fp16-fp16",
|
|
dtype: "F16",
|
|
size: 32 * 2, // 32 floats, each 2 bytes
|
|
shape: []uint64{16, 2},
|
|
setup: func(t *testing.T, f *os.File) {
|
|
u16s := make([]uint16, 32)
|
|
for i := range u16s {
|
|
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
|
}
|
|
|
|
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
},
|
|
want: []byte{
|
|
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
|
0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
|
|
0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
|
|
0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
|
|
},
|
|
},
|
|
{
|
|
name: "fp16-fp32",
|
|
dtype: "F16",
|
|
size: 32 * 2, // 32 floats, each 2 bytes
|
|
shape: []uint64{32},
|
|
setup: func(t *testing.T, f *os.File) {
|
|
u16s := make([]uint16, 32)
|
|
for i := range u16s {
|
|
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
|
}
|
|
|
|
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
},
|
|
want: []byte{
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
|
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
|
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
|
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
|
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
|
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
|
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
|
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
|
},
|
|
},
|
|
{
|
|
name: "bf16-bf16",
|
|
dtype: "BF16",
|
|
size: 32 * 2, // 32 brain floats, each 2 bytes
|
|
shape: []uint64{16, 2},
|
|
setup: func(t *testing.T, f *os.File) {
|
|
f32s := make([]float32, 32)
|
|
for i := range f32s {
|
|
f32s[i] = float32(i)
|
|
}
|
|
|
|
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
},
|
|
want: []byte{
|
|
0x00, 0x00, 0x80, 0x3f, 0x00, 0x40, 0x40, 0x40, 0x80, 0x40, 0xa0, 0x40, 0xc0, 0x40, 0xe0, 0x40,
|
|
0x00, 0x41, 0x10, 0x41, 0x20, 0x41, 0x30, 0x41, 0x40, 0x41, 0x50, 0x41, 0x60, 0x41, 0x70, 0x41,
|
|
0x80, 0x41, 0x88, 0x41, 0x90, 0x41, 0x98, 0x41, 0xa0, 0x41, 0xa8, 0x41, 0xb0, 0x41, 0xb8, 0x41,
|
|
0xc0, 0x41, 0xc8, 0x41, 0xd0, 0x41, 0xd8, 0x41, 0xe0, 0x41, 0xe8, 0x41, 0xf0, 0x41, 0xf8, 0x41,
|
|
},
|
|
},
|
|
{
|
|
name: "bf16-fp32",
|
|
dtype: "BF16",
|
|
size: 32 * 2, // 32 brain floats, each 2 bytes
|
|
shape: []uint64{32},
|
|
setup: func(t *testing.T, f *os.File) {
|
|
f32s := make([]float32, 32)
|
|
for i := range f32s {
|
|
f32s[i] = float32(i)
|
|
}
|
|
|
|
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
},
|
|
want: []byte{
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
|
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
|
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
|
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
|
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
|
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
|
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
|
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
|
},
|
|
},
|
|
{
|
|
name: "u8-u8",
|
|
dtype: "U8",
|
|
size: 32, // 32 brain floats, each 1 bytes
|
|
shape: []uint64{32},
|
|
setup: func(t *testing.T, f *os.File) {
|
|
u8s := make([]uint8, 32)
|
|
for i := range u8s {
|
|
u8s[i] = uint8(i)
|
|
}
|
|
|
|
if err := binary.Write(f, binary.LittleEndian, u8s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
},
|
|
want: []byte{
|
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
|
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range cases {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
path := filepath.Base(t.Name())
|
|
st := safetensor{
|
|
fs: root.FS(),
|
|
path: path,
|
|
dtype: tt.dtype,
|
|
offset: tt.offset,
|
|
size: tt.size,
|
|
tensorBase: &tensorBase{
|
|
name: tt.name,
|
|
shape: tt.shape,
|
|
},
|
|
}
|
|
|
|
f, err := root.Create(path)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
tt.setup(t, f)
|
|
|
|
var b bytes.Buffer
|
|
if _, err := st.WriteTo(&b); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if diff := cmp.Diff(tt.want, b.Bytes()); diff != "" {
|
|
t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSafetensorKind(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
st safetensor
|
|
expected uint32
|
|
}{
|
|
{
|
|
name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
|
|
st: safetensor{
|
|
tensorBase: &tensorBase{
|
|
name: "weight.matrix",
|
|
shape: []uint64{10, 10}, // will default to FP16
|
|
},
|
|
dtype: "BF16",
|
|
},
|
|
expected: tensorKindBF16,
|
|
},
|
|
{
|
|
name: "BF16 dtype with v. prefix should return base kind",
|
|
st: safetensor{
|
|
tensorBase: &tensorBase{
|
|
name: "v.weight.matrix",
|
|
shape: []uint64{10, 10}, // will default to FP16
|
|
},
|
|
dtype: "BF16",
|
|
},
|
|
expected: tensorKindFP16,
|
|
},
|
|
{
|
|
name: "BF16 dtype with FP32 base kind should return FP32",
|
|
st: safetensor{
|
|
tensorBase: &tensorBase{
|
|
name: "weight.matrix",
|
|
shape: []uint64{10}, // will default to FP32
|
|
},
|
|
dtype: "BF16",
|
|
},
|
|
expected: tensorKindFP32,
|
|
},
|
|
{
|
|
name: "Non-BF16 dtype should return base kind",
|
|
st: safetensor{
|
|
tensorBase: &tensorBase{
|
|
name: "weight.matrix",
|
|
shape: []uint64{10, 10}, // will default to FP16
|
|
},
|
|
dtype: "FP16",
|
|
},
|
|
expected: tensorKindFP16,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := tt.st.Kind()
|
|
if result != tt.expected {
|
|
t.Errorf("Kind() = %d, expected %d", result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|