mirror of https://github.com/ollama/ollama.git
convert: skip reading into memory when possible (#11507)
if there's no transformation to the tensor and the input and output types match, copy directly into the writer. also read from a bufio with a 32K buffer
This commit is contained in:
parent
939fe69cd0
commit
8894029077
|
@ -1,6 +1,7 @@
|
||||||
package convert
|
package convert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -124,26 +125,41 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
if seeker, ok := f.(io.Seeker); ok {
|
r, err := func() (io.Reader, error) {
|
||||||
if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
|
if readerAt, ok := f.(io.ReaderAt); ok {
|
||||||
return 0, err
|
return io.NewSectionReader(readerAt, st.offset, st.size), nil
|
||||||
}
|
} else if seeker, ok := f.(io.Seeker); ok {
|
||||||
} else {
|
_, err := seeker.Seek(st.offset, io.SeekStart)
|
||||||
if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
|
return f, err
|
||||||
return 0, err
|
} else {
|
||||||
|
_, err := io.CopyN(io.Discard, f, st.offset)
|
||||||
|
return f, err
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
br := bufio.NewReaderSize(r, min(32<<10, int(st.size)))
|
||||||
|
// special case when input and output are same type and the
|
||||||
|
// tensor doesn't need repacking
|
||||||
|
if (st.repacker == nil) &&
|
||||||
|
((st.dtype == "F32" && st.Kind() == tensorKindFP32) ||
|
||||||
|
(st.dtype == "F16" && st.Kind() == tensorKindFP16) ||
|
||||||
|
(st.dtype == "U8")) {
|
||||||
|
return io.CopyN(w, br, st.size)
|
||||||
}
|
}
|
||||||
|
|
||||||
var f32s []float32
|
var f32s []float32
|
||||||
switch st.dtype {
|
switch st.dtype {
|
||||||
case "F32":
|
case "F32":
|
||||||
f32s = make([]float32, st.size/4)
|
f32s = make([]float32, st.size/4)
|
||||||
if err = binary.Read(f, binary.LittleEndian, f32s); err != nil {
|
if err = binary.Read(br, binary.LittleEndian, f32s); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
case "F16":
|
case "F16":
|
||||||
u16s := make([]uint16, st.size/2)
|
u16s := make([]uint16, st.size/2)
|
||||||
if err = binary.Read(f, binary.LittleEndian, u16s); err != nil {
|
if err = binary.Read(br, binary.LittleEndian, u16s); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,14 +170,11 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||||
|
|
||||||
case "BF16":
|
case "BF16":
|
||||||
u8s := make([]uint8, st.size)
|
u8s := make([]uint8, st.size)
|
||||||
if err = binary.Read(f, binary.LittleEndian, u8s); err != nil {
|
if err = binary.Read(br, binary.LittleEndian, u8s); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
f32s = bfloat16.DecodeFloat32(u8s)
|
f32s = bfloat16.DecodeFloat32(u8s)
|
||||||
case "U8":
|
|
||||||
// U8 tensors do not support repacking or type conversion.
|
|
||||||
return io.CopyN(w, f, st.size)
|
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("unknown data type: %s", st.dtype)
|
return 0, fmt.Errorf("unknown data type: %s", st.dtype)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,232 @@
|
||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/d4l3k/go-bfloat16"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/x448/float16"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSafetensors(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
root, err := os.OpenRoot(t.TempDir())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer root.Close()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name,
|
||||||
|
dtype string
|
||||||
|
offset,
|
||||||
|
size int64
|
||||||
|
shape []uint64
|
||||||
|
setup func(*testing.T, *os.File)
|
||||||
|
want []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "fp32-fp32",
|
||||||
|
dtype: "F32",
|
||||||
|
size: 32 * 4, // 32 floats, each 4 bytes
|
||||||
|
shape: []uint64{32},
|
||||||
|
setup: func(t *testing.T, f *os.File) {
|
||||||
|
f32s := make([]float32, 32)
|
||||||
|
for i := range f32s {
|
||||||
|
f32s[i] = float32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
want: []byte{
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||||
|
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||||
|
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||||
|
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||||
|
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||||
|
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||||
|
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||||
|
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fp32-fp16",
|
||||||
|
dtype: "F32",
|
||||||
|
size: 32 * 4, // 32 floats, each 4 bytes
|
||||||
|
shape: []uint64{16, 2},
|
||||||
|
setup: func(t *testing.T, f *os.File) {
|
||||||
|
f32s := make([]float32, 32)
|
||||||
|
for i := range f32s {
|
||||||
|
f32s[i] = float32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
want: []byte{
|
||||||
|
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
||||||
|
0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
|
||||||
|
0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
|
||||||
|
0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fp16-fp16",
|
||||||
|
dtype: "F16",
|
||||||
|
size: 32 * 2, // 32 floats, each 2 bytes
|
||||||
|
shape: []uint64{16, 2},
|
||||||
|
setup: func(t *testing.T, f *os.File) {
|
||||||
|
u16s := make([]uint16, 32)
|
||||||
|
for i := range u16s {
|
||||||
|
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
want: []byte{
|
||||||
|
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
||||||
|
0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
|
||||||
|
0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
|
||||||
|
0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fp16-fp32",
|
||||||
|
dtype: "F16",
|
||||||
|
size: 32 * 2, // 32 floats, each 2 bytes
|
||||||
|
shape: []uint64{32},
|
||||||
|
setup: func(t *testing.T, f *os.File) {
|
||||||
|
u16s := make([]uint16, 32)
|
||||||
|
for i := range u16s {
|
||||||
|
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
want: []byte{
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||||
|
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||||
|
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||||
|
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||||
|
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||||
|
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||||
|
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||||
|
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bf16-bf16",
|
||||||
|
dtype: "BF16",
|
||||||
|
size: 32 * 2, // 32 brain floats, each 2 bytes
|
||||||
|
shape: []uint64{16, 2},
|
||||||
|
setup: func(t *testing.T, f *os.File) {
|
||||||
|
f32s := make([]float32, 32)
|
||||||
|
for i := range f32s {
|
||||||
|
f32s[i] = float32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
want: []byte{
|
||||||
|
0x00, 0x00, 0x80, 0x3f, 0x00, 0x40, 0x40, 0x40, 0x80, 0x40, 0xa0, 0x40, 0xc0, 0x40, 0xe0, 0x40,
|
||||||
|
0x00, 0x41, 0x10, 0x41, 0x20, 0x41, 0x30, 0x41, 0x40, 0x41, 0x50, 0x41, 0x60, 0x41, 0x70, 0x41,
|
||||||
|
0x80, 0x41, 0x88, 0x41, 0x90, 0x41, 0x98, 0x41, 0xa0, 0x41, 0xa8, 0x41, 0xb0, 0x41, 0xb8, 0x41,
|
||||||
|
0xc0, 0x41, 0xc8, 0x41, 0xd0, 0x41, 0xd8, 0x41, 0xe0, 0x41, 0xe8, 0x41, 0xf0, 0x41, 0xf8, 0x41,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bf16-fp32",
|
||||||
|
dtype: "BF16",
|
||||||
|
size: 32 * 2, // 32 brain floats, each 2 bytes
|
||||||
|
shape: []uint64{32},
|
||||||
|
setup: func(t *testing.T, f *os.File) {
|
||||||
|
f32s := make([]float32, 32)
|
||||||
|
for i := range f32s {
|
||||||
|
f32s[i] = float32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
want: []byte{
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||||
|
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||||
|
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||||
|
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||||
|
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||||
|
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||||
|
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||||
|
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "u8-u8",
|
||||||
|
dtype: "U8",
|
||||||
|
size: 32, // 32 brain floats, each 1 bytes
|
||||||
|
shape: []uint64{32},
|
||||||
|
setup: func(t *testing.T, f *os.File) {
|
||||||
|
u8s := make([]uint8, 32)
|
||||||
|
for i := range u8s {
|
||||||
|
u8s[i] = uint8(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, u8s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
want: []byte{
|
||||||
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||||
|
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
path := filepath.Base(t.Name())
|
||||||
|
st := safetensor{
|
||||||
|
fs: root.FS(),
|
||||||
|
path: path,
|
||||||
|
dtype: tt.dtype,
|
||||||
|
offset: tt.offset,
|
||||||
|
size: tt.size,
|
||||||
|
tensorBase: &tensorBase{
|
||||||
|
name: tt.name,
|
||||||
|
shape: tt.shape,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := root.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
tt.setup(t, f)
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := st.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.want, b.Bytes()); diff != "" {
|
||||||
|
t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue