| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | package convert | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"encoding/json" | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	"fmt" | 
					
						
							| 
									
										
										
										
											2024-04-13 04:55:12 +08:00
										 |  |  | 	"io" | 
					
						
							| 
									
										
										
										
											2024-06-30 07:53:59 +08:00
										 |  |  | 	"io/fs" | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	"log/slog" | 
					
						
							| 
									
										
										
										
											2024-06-29 04:27:05 +08:00
										 |  |  | 	"strings" | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-27 04:04:17 +08:00
										 |  |  | 	"github.com/ollama/ollama/llm" | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
// ModelParameters holds the config.json fields that ConvertModel reads for
// every architecture: which model class to convert and the vocabulary size the
// tokenizer is checked against.
type ModelParameters struct {
	// Architectures lists the model class names from config.json; ConvertModel
	// only looks at the first entry to select a converter.
	Architectures []string `json:"architectures"`
	// VocabSize is the expected number of vocabulary entries; it is zero when
	// config.json does not set vocab_size.
	VocabSize     uint32   `json:"vocab_size"`
}
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
// AdapterParameters holds the adapter_config.json fields that are common to all
// adapters handled by ConvertAdapter.
type AdapterParameters struct {
	// Alpha is the top-level lora_alpha. AdapterParameters.KV uses it only when
	// LoraParameters.Alpha is zero (i.e. unset).
	Alpha          uint32 `json:"lora_alpha"`
	// LoraLayers is decoded but not read by the code in this file.
	LoraLayers     uint32 `json:"lora_layers"`
	// LoraParameters is the nested "lora_parameters" object. A non-zero Alpha
	// here takes precedence over the top-level lora_alpha; Rank and Scale are
	// decoded but not read by the code in this file.
	LoraParameters struct {
		Rank  uint32  `json:"rank"`
		Alpha float32 `json:"alpha"`
		Scale float32 `json:"scale"`
	} `json:"lora_parameters"`
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (ModelParameters) KV(t *Tokenizer) llm.KV { | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	kv := llm.KV{ | 
					
						
							|  |  |  | 		"general.file_type":            uint32(1), | 
					
						
							|  |  |  | 		"general.quantization_version": uint32(2), | 
					
						
							|  |  |  | 		"tokenizer.ggml.pre":           t.Pre, | 
					
						
							|  |  |  | 		"tokenizer.ggml.model":         t.Vocabulary.Model, | 
					
						
							|  |  |  | 		"tokenizer.ggml.tokens":        t.Vocabulary.Tokens, | 
					
						
							|  |  |  | 		"tokenizer.ggml.scores":        t.Vocabulary.Scores, | 
					
						
							|  |  |  | 		"tokenizer.ggml.token_type":    t.Vocabulary.Types, | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-04 06:53:58 +08:00
										 |  |  | 	if len(t.Merges) > 0 { | 
					
						
							|  |  |  | 		kv["tokenizer.ggml.merges"] = t.Merges | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	if t.Template != "" { | 
					
						
							|  |  |  | 		kv["tokenizer.chat_template"] = t.Template | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-04-02 07:14:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	for _, sv := range t.SpecialVocabulary { | 
					
						
							|  |  |  | 		kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID) | 
					
						
							|  |  |  | 		kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-04-16 02:26:42 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	return kv | 
					
						
							| 
									
										
										
										
											2024-04-02 07:14:53 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | func (p AdapterParameters) KV() llm.KV { | 
					
						
							|  |  |  | 	var alpha float32 | 
					
						
							|  |  |  | 	if p.LoraParameters.Alpha == 0 { | 
					
						
							|  |  |  | 		alpha = float32(p.Alpha) | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		alpha = p.LoraParameters.Alpha | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	kv := llm.KV{ | 
					
						
							|  |  |  | 		"adapter.lora.alpha": alpha, | 
					
						
							|  |  |  | 		"adapter.type":       "lora", | 
					
						
							|  |  |  | 		"general.file_type":  uint32(1), | 
					
						
							|  |  |  | 		"general.type":       "adapter", | 
					
						
							|  |  |  | 		"general.version":    "v0.2", | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return kv | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (ModelParameters) specialTokenTypes() []string { | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	return []string{ | 
					
						
							|  |  |  | 		"bos", "eos", "unk", "sep", "pad", "cls", "mask", | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error { | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	return llm.WriteGGUF(ws, kv, ts) | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error { | 
					
						
							|  |  |  | 	return llm.WriteGGUF(ws, kv, ts) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type ModelConverter interface { | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	// KV maps parameters to LLM key-values
 | 
					
						
							|  |  |  | 	KV(*Tokenizer) llm.KV | 
					
						
							|  |  |  | 	// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
 | 
					
						
							| 
									
										
										
										
											2024-07-09 07:59:48 +08:00
										 |  |  | 	Tensors([]Tensor) []llm.Tensor | 
					
						
							| 
									
										
										
										
											2024-06-29 04:27:05 +08:00
										 |  |  | 	// Replacements returns a list of string pairs to replace in tensor names.
 | 
					
						
							|  |  |  | 	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
 | 
					
						
							|  |  |  | 	Replacements() []string | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-09 07:59:48 +08:00
										 |  |  | 	// specialTokenTypes returns any special token types the model uses
 | 
					
						
							|  |  |  | 	specialTokenTypes() []string | 
					
						
							| 
									
										
										
										
											2024-06-29 04:27:05 +08:00
										 |  |  | 	// writeFile writes the model to the provided io.WriteSeeker
 | 
					
						
							| 
									
										
										
										
											2024-07-09 07:59:48 +08:00
										 |  |  | 	writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-06 23:59:04 +08:00
										 |  |  | type moreParser interface { | 
					
						
							|  |  |  | 	parseMore(fs.FS) error | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | type AdapterConverter interface { | 
					
						
							|  |  |  | 	// KV maps parameters to LLM key-values
 | 
					
						
							|  |  |  | 	KV(llm.KV) llm.KV | 
					
						
							|  |  |  | 	// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
 | 
					
						
							|  |  |  | 	Tensors([]Tensor) []llm.Tensor | 
					
						
							|  |  |  | 	// Replacements returns a list of string pairs to replace in tensor names.
 | 
					
						
							|  |  |  | 	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
 | 
					
						
							|  |  |  | 	Replacements() []string | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error { | 
					
						
							|  |  |  | 	bts, err := fs.ReadFile(fsys, "adapter_config.json") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var p AdapterParameters | 
					
						
							|  |  |  | 	if err := json.Unmarshal(bts, &p); err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	arch, ok := baseKV["general.architecture"] | 
					
						
							|  |  |  | 	if !ok { | 
					
						
							|  |  |  | 		return errors.New("architecture not set for the base model") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var conv AdapterConverter | 
					
						
							|  |  |  | 	switch arch { | 
					
						
							|  |  |  | 	case "llama": | 
					
						
							|  |  |  | 		conv = &llamaAdapter{} | 
					
						
							|  |  |  | 	case "gemma2": | 
					
						
							|  |  |  | 		conv = &gemma2Adapter{} | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return errors.New("unsupported architecture") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...)) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := json.Unmarshal(bts, conv); err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts)) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-09 07:59:48 +08:00
										 |  |  | // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
 | 
					
						
							|  |  |  | // and files it finds in the input path.
 | 
					
						
							|  |  |  | // Supported input model formats include safetensors.
 | 
					
						
							|  |  |  | // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
 | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error { | 
					
						
							| 
									
										
										
										
											2024-06-30 07:53:59 +08:00
										 |  |  | 	bts, err := fs.ReadFile(fsys, "config.json") | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 		return err | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | 	var p ModelParameters | 
					
						
							| 
									
										
										
										
											2024-07-09 07:59:48 +08:00
										 |  |  | 	if err := json.Unmarshal(bts, &p); err != nil { | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 		return err | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	if len(p.Architectures) < 1 { | 
					
						
							|  |  |  | 		return errors.New("unknown architecture") | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | 	var conv ModelConverter | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	switch p.Architectures[0] { | 
					
						
							|  |  |  | 	case "LlamaForCausalLM", "MistralForCausalLM": | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | 		conv = &llamaModel{} | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	case "MixtralForCausalLM": | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | 		conv = &mixtralModel{} | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	case "GemmaForCausalLM": | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | 		conv = &gemmaModel{} | 
					
						
							| 
									
										
										
										
											2024-06-29 04:27:05 +08:00
										 |  |  | 	case "Gemma2ForCausalLM": | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | 		conv = &gemma2Model{} | 
					
						
							| 
									
										
										
										
											2024-06-04 06:53:58 +08:00
										 |  |  | 	case "Phi3ForCausalLM": | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | 		conv = &phi3Model{} | 
					
						
							| 
									
										
										
										
											2025-01-15 02:34:37 +08:00
										 |  |  | 	case "Qwen2ForCausalLM": | 
					
						
							|  |  |  | 		conv = &qwen2Model{} | 
					
						
							| 
									
										
										
										
											2024-06-06 23:59:04 +08:00
										 |  |  | 	case "BertModel": | 
					
						
							| 
									
										
										
										
											2024-08-24 02:29:56 +08:00
										 |  |  | 		conv = &bertModel{} | 
					
						
							| 
									
										
										
										
											2025-01-16 08:31:22 +08:00
										 |  |  | 	case "CohereForCausalLM": | 
					
						
							|  |  |  | 		conv = &commandrModel{} | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	default: | 
					
						
							|  |  |  | 		return errors.New("unsupported architecture") | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-09 07:59:48 +08:00
										 |  |  | 	if err := json.Unmarshal(bts, conv); err != nil { | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 		return err | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-06 23:59:04 +08:00
										 |  |  | 	if t, ok := conv.(moreParser); ok { | 
					
						
							|  |  |  | 		if err := t.parseMore(fsys); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-30 07:53:59 +08:00
										 |  |  | 	t, err := parseTokenizer(fsys, conv.specialTokenTypes()) | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 08:18:54 +08:00
										 |  |  | 	vocabSize := int(p.VocabSize) | 
					
						
							|  |  |  | 	switch { | 
					
						
							|  |  |  | 	case vocabSize > len(t.Vocabulary.Tokens): | 
					
						
							|  |  |  | 		slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens)) | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 		for i := range vocabSize - len(t.Vocabulary.Tokens) { | 
					
						
							|  |  |  | 			t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i)) | 
					
						
							|  |  |  | 			t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1) | 
					
						
							|  |  |  | 			t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined) | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2024-09-10 08:18:54 +08:00
										 |  |  | 	case vocabSize < len(t.Vocabulary.Tokens): | 
					
						
							|  |  |  | 		return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize) | 
					
						
							|  |  |  | 	default: | 
					
						
							| 
									
										
										
										
											2024-07-09 07:59:48 +08:00
										 |  |  | 		slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens)) | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-29 04:27:05 +08:00
										 |  |  | 	ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...)) | 
					
						
							| 
									
										
										
										
											2024-06-01 11:00:49 +08:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							| 
									
										
										
										
											2024-03-29 09:54:01 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-09 07:59:48 +08:00
										 |  |  | 	return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts)) | 
					
						
							| 
									
										
										
										
											2024-03-07 13:01:51 +08:00
										 |  |  | } |