| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | package model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	"fmt" | 
					
						
							|  |  |  | 	"image" | 
					
						
							|  |  |  | 	_ "image/jpeg" | 
					
						
							|  |  |  | 	_ "image/png" | 
					
						
							|  |  |  | 	"log/slog" | 
					
						
							|  |  |  | 	"os" | 
					
						
							|  |  |  | 	"reflect" | 
					
						
							|  |  |  | 	"strconv" | 
					
						
							|  |  |  | 	"strings" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	_ "golang.org/x/image/bmp" | 
					
						
							|  |  |  | 	_ "golang.org/x/image/tiff" | 
					
						
							|  |  |  | 	_ "golang.org/x/image/webp" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 	"github.com/ollama/ollama/kvcache" | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	"github.com/ollama/ollama/ml" | 
					
						
							|  |  |  | 	_ "github.com/ollama/ollama/ml/backend" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							// Options describes one batch of work handed to Model.Forward.
//
// Positions and Sequences are parallel slices: the package-level Forward
// rejects a batch unless both have the same, non-zero length, and passes both
// to the KV cache's StartForward.
type Options struct {
	// Inputs holds the batch's input IDs (presumably token IDs — confirm
	// against the model implementations).
	Inputs []int32
	// Positions holds one position per batch entry; its length is the batch
	// size checked by the package-level Forward.
	Positions []int32
	// Sequences identifies, for each entry of Positions, the sequence it
	// belongs to.
	Sequences []int
	// Outputs selects which entries produce output.
	// NOTE(review): meaning inferred from the name — confirm with model code.
	Outputs []int32

	// Images carries decoded image inputs for the batch, if any.
	Images []image.Image
}
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | type config struct { | 
					
						
							|  |  |  | 	Cache kvcache.Cache | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type Base struct { | 
					
						
							|  |  |  | 	b ml.Backend | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 	config | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (m *Base) Backend() ml.Backend { | 
					
						
							|  |  |  | 	return m.b | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | func (m *Base) Config() config { | 
					
						
							|  |  |  | 	return m.config | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | type Model interface { | 
					
						
							|  |  |  | 	Forward(ml.Context, Options) (ml.Tensor, error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	Backend() ml.Backend | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 	Config() config | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | var models = make(map[string]func(ml.Config) (Model, error)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func Register(name string, f func(ml.Config) (Model, error)) { | 
					
						
							|  |  |  | 	if _, ok := models[name]; ok { | 
					
						
							|  |  |  | 		panic("model: model already registered") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	models[name] = f | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func New(s string) (Model, error) { | 
					
						
							|  |  |  | 	r, err := os.Open(s) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	defer r.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	b, err := ml.NewBackend(r) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	arch := b.Config().Architecture() | 
					
						
							|  |  |  | 	f, ok := models[arch] | 
					
						
							|  |  |  | 	if !ok { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("unsupported model architecture %q", arch) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	m, err := f(b.Config()) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 	base := Base{b: b, config: m.Config()} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	v := reflect.ValueOf(m) | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 	v.Elem().Set(populateFields(base, v.Elem())) | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	return m, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	t := v.Type() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if t.Kind() == reflect.Struct { | 
					
						
							|  |  |  | 		allNil := true | 
					
						
							|  |  |  | 		for i := range t.NumField() { | 
					
						
							|  |  |  | 			tt := t.Field(i).Type | 
					
						
							|  |  |  | 			vv := v.Field(i) | 
					
						
							|  |  |  | 			if !vv.CanSet() { | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// make a copy
 | 
					
						
							|  |  |  | 			tagsCopy := tags | 
					
						
							|  |  |  | 			if tag := t.Field(i).Tag.Get("gguf"); tag != "" { | 
					
						
							|  |  |  | 				tagsCopy = append(tagsCopy, ParseTags(tag)) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			if tt == reflect.TypeOf((*Base)(nil)).Elem() { | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 				vv.Set(reflect.ValueOf(base)) | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 			} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { | 
					
						
							|  |  |  | 				var fn func([]Tag) [][]string | 
					
						
							|  |  |  | 				fn = func(tags []Tag) (values [][]string) { | 
					
						
							|  |  |  | 					if len(tags) < 1 { | 
					
						
							|  |  |  | 						return nil | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					values = [][]string{{tags[0].Name}} | 
					
						
							|  |  |  | 					for _, alt := range tags[0].Alternate { | 
					
						
							|  |  |  | 						values = append(values, []string{alt}) | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					for i, value := range values { | 
					
						
							|  |  |  | 						for _, rest := range fn(tags[1:]) { | 
					
						
							|  |  |  | 							value = append(value, rest...) | 
					
						
							|  |  |  | 						} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 						values[i] = value | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					return values | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				names := fn(tagsCopy) | 
					
						
							|  |  |  | 				for _, name := range names { | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 					if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 						slog.Debug("found tensor", "", tensor) | 
					
						
							|  |  |  | 						vv.Set(reflect.ValueOf(tensor)) | 
					
						
							|  |  |  | 						break | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 				} | 
					
						
							| 
									
										
										
										
											2025-01-15 08:12:14 +08:00
										 |  |  | 			} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface { | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 				setPointer(base, vv, tagsCopy) | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 			} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array { | 
					
						
							|  |  |  | 				for i := range vv.Len() { | 
					
						
							| 
									
										
										
										
											2025-01-15 08:12:14 +08:00
										 |  |  | 					vvv := vv.Index(i) | 
					
						
							|  |  |  | 					if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 						setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) | 
					
						
							| 
									
										
										
										
											2025-01-15 08:12:14 +08:00
										 |  |  | 					} else { | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 						vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) | 
					
						
							| 
									
										
										
										
											2025-01-15 08:12:14 +08:00
										 |  |  | 					} | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			if !canNil(tt) || !vv.IsNil() { | 
					
						
							|  |  |  | 				allNil = false | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if allNil { | 
					
						
							|  |  |  | 			return reflect.Zero(t) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return v | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | func setPointer(base Base, v reflect.Value, tags []Tag) { | 
					
						
							| 
									
										
										
										
											2025-01-15 08:12:14 +08:00
										 |  |  | 	vv := v | 
					
						
							|  |  |  | 	if v.Kind() == reflect.Interface { | 
					
						
							|  |  |  | 		if v.IsNil() { | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		vv = vv.Elem() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	vv = vv.Elem() | 
					
						
							|  |  |  | 	if v.IsNil() { | 
					
						
							|  |  |  | 		vv = reflect.New(v.Type().Elem()).Elem() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 	if f := populateFields(base, vv, tags...); f.CanAddr() { | 
					
						
							| 
									
										
										
										
											2025-01-15 08:12:14 +08:00
										 |  |  | 		v.Set(f.Addr()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 // Tag is one parsed `gguf` struct tag: the primary name component plus any
// alternate components that may be tried in its place when looking up tensors.
type Tag struct {
	Name      string
	Alternate []string
}

// ParseTags parses a `gguf` struct tag value of the form
// "name[,alt:other]...". The first comma-separated element becomes Name. Each
// later element that starts with "alt:" contributes its remainder to
// Alternate, in order, and any other element is ignored. Elements are not
// trimmed, so " alt:x" is not recognized as an alternate.
func ParseTags(s string) (tag Tag) {
	name, rest, more := strings.Cut(s, ",")
	tag.Name = name

	for more {
		var part string
		part, rest, more = strings.Cut(rest, ",")
		if alt, isAlt := strings.CutPrefix(part, "alt:"); isAlt {
			tag.Alternate = append(tag.Alternate, alt)
		}
	}

	return tag
}
					
						
							|  |  |  | 
 | 
					
						
							// canNil reports whether t is a chan, func, interface, map, pointer or slice
// type — kinds whose zero value is nil, so reflect.Value.IsNil is safe to
// call on them. (IsNil also accepts unsafe.Pointer, which this does not
// report.)
func canNil(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		return true
	default:
		return false
	}
}
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) { | 
					
						
							|  |  |  | 	if len(opts.Positions) != len(opts.Sequences) { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences)) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if len(opts.Positions) < 1 { | 
					
						
							|  |  |  | 		return nil, errors.New("batch size cannot be less than 1") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	cache := m.Config().Cache | 
					
						
							|  |  |  | 	if cache != nil { | 
					
						
							|  |  |  | 		err := cache.StartForward(ctx, opts.Positions, opts.Sequences) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	t, err := m.Forward(ctx, opts) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-04 11:35:12 +08:00
										 |  |  | 	ctx.Forward(t) | 
					
						
							|  |  |  | 	ctx.Compute(t) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return t, nil | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | } |