ggml: Seperate tensor load from backend creation

Currently, when the backend is created, the tensors are loaded at the
same time, which is a slow operation. This separates them to be two
steps:
 - Create backend, including enumerating tensors and memory allocation
 - Loading tensor data

This allows more flexibility in managing model loading.
This commit is contained in:
Jesse Gross 2025-04-17 13:42:40 -07:00 committed by Jesse Gross
parent d755577473
commit 94ab428e3f
13 changed files with 131 additions and 115 deletions

View File

@ -47,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
}
t.Cleanup(func() { r.Close() })
m, _, err := ggml.Decode(r, -1)
m, err := ggml.Decode(r, -1)
if err != nil {
t.Fatal(err)
}
@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
}
defer r.Close()
m, _, err := ggml.Decode(r, -1)
m, err := ggml.Decode(r, -1)
if err != nil {
t.Fatal(err)
}

View File

@ -15,6 +15,7 @@ import (
type GGML struct {
container
model
Length int64
}
type model interface {
@ -386,12 +387,12 @@ func DetectContentType(b []byte) string {
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If the maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
return nil, 0, err
return nil, err
}
var c container
@ -401,24 +402,25 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
case FILE_MAGIC_GGUF_BE:
c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
default:
return nil, 0, errors.New("invalid file magic")
return nil, errors.New("invalid file magic")
}
model, err := c.Decode(rs)
if err != nil {
return nil, 0, err
return nil, err
}
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return nil, 0, err
return nil, err
}
// final model type
return &GGML{
container: c,
model: model,
}, offset, nil
Length: offset,
}, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {

View File

@ -35,7 +35,7 @@ func TestWriteGGUF(t *testing.T) {
}
defer r.Close()
ff, _, err := Decode(r, 0)
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}

View File

@ -423,7 +423,7 @@ func projectorMemoryRequirements(filename string) (weights uint64) {
}
defer file.Close()
ggml, _, err := ggml.Decode(file, 1024)
ggml, err := ggml.Decode(file, 1024)
if err != nil {
return 0
}

View File

@ -121,7 +121,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
}
defer f.Close()
ggml, _, err := ggml.Decode(f, maxArraySize)
ggml, err := ggml.Decode(f, maxArraySize)
return ggml, err
}

View File

@ -6,7 +6,6 @@ import (
"encoding/binary"
"fmt"
"math"
"os"
"slices"
"strconv"
"strings"
@ -15,6 +14,7 @@ import (
)
type Backend interface {
Load(ctx context.Context, progress func(float32)) error
Config() fs.Config
Get(name string) Tensor
NewContext() Context
@ -52,10 +52,6 @@ type CacheConfig struct {
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// Progress is a callback function that allows reporting percentage completion
// of model loading
Progress func(float32)
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
@ -72,9 +68,9 @@ type BackendParams struct {
FlashAttention bool
}
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
var backends = make(map[string]func(string, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@ -82,9 +78,9 @@ func RegisterBackend(name string, f func(context.Context, *os.File, BackendParam
backends[name] = f
}
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(ctx, f, params)
return backend(modelPath, params)
}
return nil, fmt.Errorf("unsupported backend")

View File

@ -44,8 +44,15 @@ func devices() []*C.struct_ggml_backend_device {
}
type Backend struct {
// modelPath is the location of the model data
modelPath string
meta *fsggml.GGML
// tensorLoadTargets maps from the name of the tensor in the file
// to the name that is used by the model definition
tensorLoadTargets map[string][]string
sched *C.struct_ggml_backend_sched
schedBackends []*C.struct_ggml_backend
schedBufts []*C.struct_ggml_backend_buffer_type
@ -64,8 +71,14 @@ type Backend struct {
maxGraphNodes int
}
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fsggml.Decode(r, -1)
func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
meta, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}
@ -307,73 +320,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
}
}
var doneBytes atomic.Uint64
totalBytes := uint64(n) - meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range meta.Tensors().Items() {
t := t
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
for i := range tts {
target := targets[t.Name][i]
if target == "" {
target = t.Name
}
tt, ok := tensors[target]
if !ok {
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
tts[i] = tt
}
// Create a new FD for each goroutine so that each FD is read sequentially, rather than
// seeking around within an FD shared between all goroutines.
file, err := os.Open(r.Name())
if err != nil {
slog.Warn("file open error", "file", r.Name(), "error", err)
return err
}
defer file.Close()
sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
if err := ctx.Err(); err != nil {
return err
}
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil {
slog.Warn("file read error", "file", r.Name(), "error", err)
return err
}
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
}
s += uint64(n)
if params.Progress != nil {
done := doneBytes.Add(uint64(n))
params.Progress(float32(done) / float32(totalBytes))
}
}
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
// map devices to backend buffer types so new tensors can be assigned to the correct device
deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)
@ -397,8 +343,10 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
return &Backend{
modelPath: modelPath,
flashAttention: params.FlashAttention,
meta: meta,
tensorLoadTargets: targets,
tensors: tensors,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
@ -426,6 +374,77 @@ func init() {
ml.RegisterBackend("ggml", New)
}
func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
var doneBytes atomic.Uint64
totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range b.meta.Tensors().Items() {
t := t
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name])))
for i := range tts {
target := b.tensorLoadTargets[t.Name][i]
if target == "" {
target = t.Name
}
tt, ok := b.tensors[target]
if !ok {
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
tts[i] = tt
}
// Create a new FD for each goroutine so that each FD is read sequentially, rather than
// seeking around within an FD shared between all goroutines.
file, err := os.Open(b.modelPath)
if err != nil {
slog.Warn("file open error", "file", b.modelPath, "error", err)
return err
}
defer file.Close()
sr := io.NewSectionReader(file, int64(b.meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
if err := ctx.Err(); err != nil {
return err
}
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil {
slog.Warn("file read error", "file", b.modelPath, "error", err)
return err
}
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
}
s += uint64(n)
if progress != nil {
done := doneBytes.Add(uint64(n))
progress(float32(done) / float32(totalBytes))
}
}
return nil
})
}
if err := g.Wait(); err != nil {
return err
}
return nil
}
func (b *Backend) Config() fs.Config {
return b.meta.KV()
}

View File

@ -98,14 +98,8 @@ func Register(name string, f func(fs.Config) (Model, error)) {
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(ctx, r, params)
func New(modelPath string, params ml.BackendParams) (Model, error) {
b, err := ml.NewBackend(modelPath, params)
if err != nil {
return nil, err
}
@ -134,7 +128,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
defer r.Close()
meta, _, err := fsggml.Decode(r, -1)
meta, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}

View File

@ -845,7 +845,7 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(ctx, mpath, params)
s.model, err = model.New(mpath, params)
if err != nil {
panic(err)
}
@ -874,6 +874,14 @@ func (s *Server) loadModel(
panic(err)
}
err = s.model.Backend().Load(ctx,
func(progress float32) {
s.progress = progress
})
if err != nil {
panic(err)
}
s.status = llm.ServerStatusReady
s.ready.Done()
}
@ -928,9 +936,6 @@ func Execute(args []string) error {
}
params := ml.BackendParams{
Progress: func(progress float32) {
server.progress = progress
},
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,

View File

@ -295,7 +295,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
}
defer bin.Close()
f, _, err := ggml.Decode(bin, -1)
f, err := ggml.Decode(bin, -1)
if err != nil {
return nil, err
}
@ -467,7 +467,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err
}
f, _, err := ggml.Decode(temp, 1024)
f, err := ggml.Decode(temp, 1024)
if err != nil {
slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
return nil, err
@ -508,7 +508,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
var offset int64
for offset < stat.Size() {
f, n, err := ggml.Decode(blob, 1024)
f, err := ggml.Decode(blob, 1024)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
@ -523,7 +523,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
}
var layer Layer
if digest != "" && n == stat.Size() && offset == 0 {
if digest != "" && f.Length == stat.Size() && offset == 0 {
layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
@ -533,14 +533,14 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
if layer.Digest == "" {
layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype)
layer, err = NewLayer(io.NewSectionReader(blob, offset, f.Length), mediatype)
if err != nil {
return nil, err
}
}
layers = append(layers, &layerGGML{layer, f})
offset = n
offset = f.Length
}
return detectChatTemplate(layers)

View File

@ -75,7 +75,7 @@ func (m *Model) Capabilities() []model.Capability {
if err == nil {
defer r.Close()
f, _, err := ggml.Decode(r, 1024)
f, err := ggml.Decode(r, 1024)
if err == nil {
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding)

View File

@ -64,7 +64,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}
defer blob.Close()
f, _, err := ggml.Decode(blob, -1)
f, err := ggml.Decode(blob, -1)
if err != nil {
return nil, err
}

View File

@ -271,7 +271,7 @@ func TestQuantizeModel(t *testing.T) {
t.Fatal(err.Error())
}
defer fp.Close()
meta, _, err := fsggml.Decode(fp, -1)
meta, err := fsggml.Decode(fp, -1)
if err != nil {
t.Fatal(err.Error())
}
@ -303,7 +303,7 @@ func TestQuantizeModel(t *testing.T) {
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
}
defer fpNew.Close()
newMeta, _, err := fsggml.Decode(fpNew, -1)
newMeta, err := fsggml.Decode(fpNew, -1)
if err != nil {
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
}