| 
									
										
										
										
											2024-05-21 02:26:45 +08:00
										 |  |  | package parser | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"bufio" | 
					
						
							| 
									
										
										
										
											2023-07-26 01:22:23 +08:00
										 |  |  | 	"bytes" | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 	"crypto/sha256" | 
					
						
							| 
									
										
										
										
											2023-07-18 05:21:27 +08:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2023-07-28 00:55:48 +08:00
										 |  |  | 	"fmt" | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | 	"io" | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 	"net/http" | 
					
						
							|  |  |  | 	"os" | 
					
						
							|  |  |  | 	"os/user" | 
					
						
							|  |  |  | 	"path/filepath" | 
					
						
							| 
									
										
										
										
											2025-04-05 08:33:07 +08:00
										 |  |  | 	"runtime" | 
					
						
							| 
									
										
										
										
											2025-01-09 03:22:01 +08:00
										 |  |  | 	"slices" | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 	"strconv" | 
					
						
							|  |  |  | 	"strings" | 
					
						
							| 
									
										
										
										
											2025-04-05 08:33:07 +08:00
										 |  |  | 	"sync" | 
					
						
							| 
									
										
										
										
											2024-06-14 02:39:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-05 08:33:07 +08:00
										 |  |  | 	"golang.org/x/sync/errgroup" | 
					
						
							| 
									
										
										
										
											2024-06-14 02:39:01 +08:00
										 |  |  | 	"golang.org/x/text/encoding/unicode" | 
					
						
							|  |  |  | 	"golang.org/x/text/transform" | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/ollama/ollama/api" | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | var ErrModelNotFound = errors.New("no Modelfile or safetensors files found") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type Modelfile struct { | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 	Commands []Command | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | func (f Modelfile) String() string { | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 	var sb strings.Builder | 
					
						
							|  |  |  | 	for _, cmd := range f.Commands { | 
					
						
							|  |  |  | 		fmt.Fprintln(&sb, cmd.String()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return sb.String() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-06 05:54:40 +08:00
										 |  |  | var deprecatedParameters = []string{ | 
					
						
							|  |  |  | 	"penalize_newline", | 
					
						
							|  |  |  | 	"low_vram", | 
					
						
							|  |  |  | 	"f16_kv", | 
					
						
							|  |  |  | 	"logits_all", | 
					
						
							|  |  |  | 	"vocab_only", | 
					
						
							|  |  |  | 	"use_mlock", | 
					
						
							| 
									
										
										
										
											2025-05-08 23:31:08 +08:00
										 |  |  | 	"mirostat", | 
					
						
							|  |  |  | 	"mirostat_tau", | 
					
						
							|  |  |  | 	"mirostat_eta", | 
					
						
							| 
									
										
										
										
											2025-05-06 05:54:40 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2025-01-09 03:22:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | // CreateRequest creates a new *api.CreateRequest from an existing Modelfile
 | 
					
						
							| 
									
										
										
										
											2025-01-11 08:14:08 +08:00
										 |  |  | func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) { | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 	req := &api.CreateRequest{} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var messages []api.Message | 
					
						
							|  |  |  | 	var licenses []string | 
					
						
							|  |  |  | 	params := make(map[string]any) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, c := range f.Commands { | 
					
						
							|  |  |  | 		switch c.Name { | 
					
						
							|  |  |  | 		case "model": | 
					
						
							| 
									
										
										
										
											2025-01-11 08:14:08 +08:00
										 |  |  | 			path, err := expandPath(c.Args, relativeDir) | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			digestMap, err := fileDigestMap(path) | 
					
						
							|  |  |  | 			if errors.Is(err, os.ErrNotExist) { | 
					
						
							|  |  |  | 				req.From = c.Args | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} else if err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-16 16:14:04 +08:00
										 |  |  | 			if req.Files == nil { | 
					
						
							|  |  |  | 				req.Files = digestMap | 
					
						
							|  |  |  | 			} else { | 
					
						
							|  |  |  | 				for k, v := range digestMap { | 
					
						
							|  |  |  | 					req.Files[k] = v | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 		case "adapter": | 
					
						
							| 
									
										
										
										
											2025-01-11 08:14:08 +08:00
										 |  |  | 			path, err := expandPath(c.Args, relativeDir) | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			digestMap, err := fileDigestMap(path) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			req.Adapters = digestMap | 
					
						
							|  |  |  | 		case "template": | 
					
						
							|  |  |  | 			req.Template = c.Args | 
					
						
							|  |  |  | 		case "system": | 
					
						
							|  |  |  | 			req.System = c.Args | 
					
						
							|  |  |  | 		case "license": | 
					
						
							|  |  |  | 			licenses = append(licenses, c.Args) | 
					
						
							|  |  |  | 		case "message": | 
					
						
							|  |  |  | 			role, msg, _ := strings.Cut(c.Args, ": ") | 
					
						
							|  |  |  | 			messages = append(messages, api.Message{Role: role, Content: msg}) | 
					
						
							|  |  |  | 		default: | 
					
						
							| 
									
										
										
										
											2025-01-09 03:22:01 +08:00
										 |  |  | 			if slices.Contains(deprecatedParameters, c.Name) { | 
					
						
							|  |  |  | 				fmt.Printf("warning: parameter %s is deprecated\n", c.Name) | 
					
						
							|  |  |  | 				break | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 			ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}}) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			for k, v := range ps { | 
					
						
							|  |  |  | 				if ks, ok := params[k].([]string); ok { | 
					
						
							|  |  |  | 					params[k] = append(ks, v.([]string)...) | 
					
						
							|  |  |  | 				} else if vs, ok := v.([]string); ok { | 
					
						
							|  |  |  | 					params[k] = vs | 
					
						
							|  |  |  | 				} else { | 
					
						
							|  |  |  | 					params[k] = v | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if len(params) > 0 { | 
					
						
							|  |  |  | 		req.Parameters = params | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if len(messages) > 0 { | 
					
						
							|  |  |  | 		req.Messages = messages | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if len(licenses) > 0 { | 
					
						
							|  |  |  | 		req.License = licenses | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return req, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func fileDigestMap(path string) (map[string]string, error) { | 
					
						
							|  |  |  | 	fl := make(map[string]string) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	fi, err := os.Stat(path) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var files []string | 
					
						
							|  |  |  | 	if fi.IsDir() { | 
					
						
							| 
									
										
										
										
											2025-05-06 02:59:26 +08:00
										 |  |  | 		fs, err := filesForModel(path) | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2025-05-06 02:59:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 		for _, f := range fs { | 
					
						
							|  |  |  | 			f, err := filepath.EvalSymlinks(f) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			rel, err := filepath.Rel(path, f) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			if !filepath.IsLocal(rel) { | 
					
						
							|  |  |  | 				return nil, fmt.Errorf("insecure path: %s", rel) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			files = append(files, f) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 	} else { | 
					
						
							|  |  |  | 		files = []string{path} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-05 08:33:07 +08:00
										 |  |  | 	var mu sync.Mutex | 
					
						
							|  |  |  | 	var g errgroup.Group | 
					
						
							|  |  |  | 	g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1)) | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 	for _, f := range files { | 
					
						
							| 
									
										
										
										
											2025-04-05 08:33:07 +08:00
										 |  |  | 		g.Go(func() error { | 
					
						
							|  |  |  | 			digest, err := digestForFile(f) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			mu.Lock() | 
					
						
							|  |  |  | 			defer mu.Unlock() | 
					
						
							|  |  |  | 			fl[f] = digest | 
					
						
							|  |  |  | 			return nil | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := g.Wait(); err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return fl, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func digestForFile(filename string) (string, error) { | 
					
						
							|  |  |  | 	filepath, err := filepath.EvalSymlinks(filename) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return "", err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	bin, err := os.Open(filepath) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return "", err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	defer bin.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	hash := sha256.New() | 
					
						
							|  |  |  | 	if _, err := io.Copy(hash, bin); err != nil { | 
					
						
							|  |  |  | 		return "", err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func filesForModel(path string) ([]string, error) { | 
					
						
							|  |  |  | 	detectContentType := func(path string) (string, error) { | 
					
						
							|  |  |  | 		f, err := os.Open(path) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return "", err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		defer f.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		var b bytes.Buffer | 
					
						
							|  |  |  | 		b.Grow(512) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) { | 
					
						
							|  |  |  | 			return "", err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";") | 
					
						
							|  |  |  | 		return contentType, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	glob := func(pattern, contentType string) ([]string, error) { | 
					
						
							|  |  |  | 		matches, err := filepath.Glob(pattern) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-06 02:59:26 +08:00
										 |  |  | 		for _, match := range matches { | 
					
						
							|  |  |  | 			if ct, err := detectContentType(match); err != nil { | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} else if ct != contentType { | 
					
						
							| 
									
										
										
										
											2025-05-06 02:59:26 +08:00
										 |  |  | 				return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match) | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		return matches, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var files []string | 
					
						
							| 
									
										
										
										
											2025-03-15 07:56:32 +08:00
										 |  |  | 	if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 { | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 		// safetensors files might be unresolved git lfs references; skip if they are
 | 
					
						
							|  |  |  | 		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
 | 
					
						
							|  |  |  | 		files = append(files, st...) | 
					
						
							|  |  |  | 	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { | 
					
						
							|  |  |  | 		// pytorch files might also be unresolved git lfs references; skip if they are
 | 
					
						
							|  |  |  | 		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
 | 
					
						
							|  |  |  | 		files = append(files, pt...) | 
					
						
							|  |  |  | 	} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 { | 
					
						
							|  |  |  | 		// pytorch files might also be unresolved git lfs references; skip if they are
 | 
					
						
							|  |  |  | 		// covers consolidated.x.pth, consolidated.pth
 | 
					
						
							|  |  |  | 		files = append(files, pt...) | 
					
						
							|  |  |  | 	} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 { | 
					
						
							|  |  |  | 		// covers gguf files ending in .gguf
 | 
					
						
							|  |  |  | 		files = append(files, gg...) | 
					
						
							|  |  |  | 	} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 { | 
					
						
							|  |  |  | 		// covers gguf files ending in .bin
 | 
					
						
							|  |  |  | 		files = append(files, gg...) | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		return nil, ErrModelNotFound | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// add configuration files, json files are detected as text/plain
 | 
					
						
							|  |  |  | 	js, err := glob(filepath.Join(path, "*.json"), "text/plain") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	files = append(files, js...) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// bert models require a nested config.json
 | 
					
						
							|  |  |  | 	// TODO(mxyng): merge this with the glob above
 | 
					
						
							|  |  |  | 	js, err = glob(filepath.Join(path, "**/*.json"), "text/plain") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	files = append(files, js...) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-12 03:10:35 +08:00
										 |  |  | 	// only include tokenizer.model is tokenizer.json is not present
 | 
					
						
							|  |  |  | 	if !slices.ContainsFunc(files, func(s string) bool { | 
					
						
							|  |  |  | 		return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json") | 
					
						
							|  |  |  | 	}) { | 
					
						
							|  |  |  | 		if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { | 
					
						
							|  |  |  | 			// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
 | 
					
						
							|  |  |  | 			// tokenizer.model might be a unresolved git lfs reference; error if it is
 | 
					
						
							|  |  |  | 			files = append(files, tks...) | 
					
						
							|  |  |  | 		} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { | 
					
						
							|  |  |  | 			// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
 | 
					
						
							|  |  |  | 			files = append(files, tks...) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return files, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | type Command struct { | 
					
						
							|  |  |  | 	Name string | 
					
						
							| 
									
										
										
										
											2023-07-18 05:21:27 +08:00
										 |  |  | 	Args string | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | func (c Command) String() string { | 
					
						
							| 
									
										
										
										
											2024-05-02 01:01:09 +08:00
										 |  |  | 	var sb strings.Builder | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 	switch c.Name { | 
					
						
							|  |  |  | 	case "model": | 
					
						
							| 
									
										
										
										
											2024-05-02 01:01:09 +08:00
										 |  |  | 		fmt.Fprintf(&sb, "FROM %s", c.Args) | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 	case "license", "template", "system", "adapter": | 
					
						
							| 
									
										
										
										
											2024-05-02 01:01:09 +08:00
										 |  |  | 		fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 	case "message": | 
					
						
							|  |  |  | 		role, message, _ := strings.Cut(c.Args, ": ") | 
					
						
							| 
									
										
										
										
											2024-05-02 01:01:09 +08:00
										 |  |  | 		fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message)) | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 	default: | 
					
						
							| 
									
										
										
										
											2024-05-02 01:01:09 +08:00
										 |  |  | 		fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args)) | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-02 01:01:09 +08:00
										 |  |  | 	return sb.String() | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | type state int | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | const ( | 
					
						
							|  |  |  | 	stateNil state = iota | 
					
						
							|  |  |  | 	stateName | 
					
						
							|  |  |  | 	stateValue | 
					
						
							|  |  |  | 	stateParameter | 
					
						
							|  |  |  | 	stateMessage | 
					
						
							|  |  |  | 	stateComment | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-25 07:12:56 +08:00
										 |  |  | var ( | 
					
						
							| 
									
										
										
										
											2025-03-22 03:38:09 +08:00
										 |  |  | 	errMissingFrom        = errors.New("no FROM line") | 
					
						
							|  |  |  | 	errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") | 
					
						
							|  |  |  | 	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") | 
					
						
							| 
									
										
										
										
											2024-04-25 07:12:56 +08:00
										 |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 05:59:44 +08:00
										 |  |  | type ParserError struct { | 
					
						
							|  |  |  | 	LineNumber int | 
					
						
							|  |  |  | 	Msg        string | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (e *ParserError) Error() string { | 
					
						
							|  |  |  | 	if e.LineNumber > 0 { | 
					
						
							|  |  |  | 		return fmt.Sprintf("(line %d): %s", e.LineNumber, e.Msg) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return e.Msg | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | func ParseFile(r io.Reader) (*Modelfile, error) { | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 	var cmd Command | 
					
						
							|  |  |  | 	var curr state | 
					
						
							| 
									
										
										
										
											2024-11-15 05:59:44 +08:00
										 |  |  | 	var currLine int = 1 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 	var b bytes.Buffer | 
					
						
							|  |  |  | 	var role string | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 	var f Modelfile | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-14 02:39:01 +08:00
										 |  |  | 	tr := unicode.BOMOverride(unicode.UTF8.NewDecoder()) | 
					
						
							|  |  |  | 	br := bufio.NewReader(transform.NewReader(r, tr)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-14 01:22:16 +08:00
										 |  |  | 	for { | 
					
						
							|  |  |  | 		r, _, err := br.ReadRune() | 
					
						
							|  |  |  | 		if errors.Is(err, io.EOF) { | 
					
						
							|  |  |  | 			break | 
					
						
							|  |  |  | 		} else if err != nil { | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-07-18 05:21:27 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 05:59:44 +08:00
										 |  |  | 		if isNewline(r) { | 
					
						
							|  |  |  | 			currLine++ | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 		next, r, err := parseRuneForState(r, curr) | 
					
						
							|  |  |  | 		if errors.Is(err, io.ErrUnexpectedEOF) { | 
					
						
							|  |  |  | 			return nil, fmt.Errorf("%w: %s", err, b.String()) | 
					
						
							|  |  |  | 		} else if err != nil { | 
					
						
							| 
									
										
										
										
											2024-11-15 05:59:44 +08:00
										 |  |  | 			return nil, &ParserError{ | 
					
						
							|  |  |  | 				LineNumber: currLine, | 
					
						
							|  |  |  | 				Msg:        err.Error(), | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-27 06:13:27 +08:00
										 |  |  | 		// process the state transition, some transitions need to be intercepted and redirected
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 		if next != curr { | 
					
						
							|  |  |  | 			switch curr { | 
					
						
							| 
									
										
										
										
											2024-04-27 08:11:47 +08:00
										 |  |  | 			case stateName: | 
					
						
							|  |  |  | 				if !isValidCommand(b.String()) { | 
					
						
							| 
									
										
										
										
											2024-11-15 05:59:44 +08:00
										 |  |  | 					return nil, &ParserError{ | 
					
						
							|  |  |  | 						LineNumber: currLine, | 
					
						
							|  |  |  | 						Msg:        errInvalidCommand.Error(), | 
					
						
							|  |  |  | 					} | 
					
						
							| 
									
										
										
										
											2024-04-27 08:11:47 +08:00
										 |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-27 06:13:27 +08:00
										 |  |  | 				// next state sometimes depends on the current buffer value
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 				switch s := strings.ToLower(b.String()); s { | 
					
						
							|  |  |  | 				case "from": | 
					
						
							|  |  |  | 					cmd.Name = "model" | 
					
						
							|  |  |  | 				case "parameter": | 
					
						
							| 
									
										
										
										
											2024-04-27 06:13:27 +08:00
										 |  |  | 					// transition to stateParameter which sets command name
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 					next = stateParameter | 
					
						
							|  |  |  | 				case "message": | 
					
						
							| 
									
										
										
										
											2024-04-27 06:13:27 +08:00
										 |  |  | 					// transition to stateMessage which validates the message role
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 					next = stateMessage | 
					
						
							|  |  |  | 					fallthrough | 
					
						
							|  |  |  | 				default: | 
					
						
							|  |  |  | 					cmd.Name = s | 
					
						
							|  |  |  | 				} | 
					
						
							| 
									
										
										
										
											2024-04-27 08:11:47 +08:00
										 |  |  | 			case stateParameter: | 
					
						
							|  |  |  | 				cmd.Name = b.String() | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 			case stateMessage: | 
					
						
							| 
									
										
										
										
											2025-03-22 03:38:09 +08:00
										 |  |  | 				if !isValidMessageRole(b.String()) { | 
					
						
							|  |  |  | 					return nil, &ParserError{ | 
					
						
							|  |  |  | 						LineNumber: currLine, | 
					
						
							|  |  |  | 						Msg:        errInvalidMessageRole.Error(), | 
					
						
							|  |  |  | 					} | 
					
						
							| 
									
										
										
										
											2025-03-21 04:11:17 +08:00
										 |  |  | 				} | 
					
						
							| 
									
										
										
										
											2025-03-22 03:38:09 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 				role = b.String() | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 			case stateComment, stateNil: | 
					
						
							|  |  |  | 				// pass
 | 
					
						
							|  |  |  | 			case stateValue: | 
					
						
							| 
									
										
										
										
											2024-06-28 02:18:38 +08:00
										 |  |  | 				s, ok := unquote(strings.TrimSpace(b.String())) | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 				if !ok || isSpace(r) { | 
					
						
							|  |  |  | 					if _, err := b.WriteRune(r); err != nil { | 
					
						
							|  |  |  | 						return nil, err | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					continue | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				if role != "" { | 
					
						
							|  |  |  | 					s = role + ": " + s | 
					
						
							|  |  |  | 					role = "" | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				cmd.Args = s | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 				f.Commands = append(f.Commands, cmd) | 
					
						
							| 
									
										
										
										
											2023-08-11 07:09:02 +08:00
										 |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 			b.Reset() | 
					
						
							|  |  |  | 			curr = next | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if strconv.IsPrint(r) { | 
					
						
							|  |  |  | 			if _, err := b.WriteRune(r); err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							| 
									
										
										
										
											2023-08-11 07:22:08 +08:00
										 |  |  | 			} | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// flush the buffer
 | 
					
						
							|  |  |  | 	switch curr { | 
					
						
							|  |  |  | 	case stateComment, stateNil: | 
					
						
							|  |  |  | 		// pass; nothing to flush
 | 
					
						
							|  |  |  | 	case stateValue: | 
					
						
							| 
									
										
										
										
											2024-06-28 02:18:38 +08:00
										 |  |  | 		s, ok := unquote(strings.TrimSpace(b.String())) | 
					
						
							| 
									
										
										
										
											2024-04-25 10:17:26 +08:00
										 |  |  | 		if !ok { | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 			return nil, io.ErrUnexpectedEOF | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-07-18 05:21:27 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-25 10:17:26 +08:00
										 |  |  | 		if role != "" { | 
					
						
							|  |  |  | 			s = role + ": " + s | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		cmd.Args = s | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 		f.Commands = append(f.Commands, cmd) | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 	default: | 
					
						
							|  |  |  | 		return nil, io.ErrUnexpectedEOF | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 	for _, cmd := range f.Commands { | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 		if cmd.Name == "model" { | 
					
						
							| 
									
										
										
										
											2024-05-01 01:55:19 +08:00
										 |  |  | 			return &f, nil | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-25 07:12:56 +08:00
										 |  |  | 	return nil, errMissingFrom | 
					
						
							| 
									
										
										
										
											2023-07-17 08:02:22 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2023-07-18 05:21:27 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | func parseRuneForState(r rune, cs state) (state, rune, error) { | 
					
						
							|  |  |  | 	switch cs { | 
					
						
							|  |  |  | 	case stateNil: | 
					
						
							|  |  |  | 		switch { | 
					
						
							|  |  |  | 		case r == '#': | 
					
						
							|  |  |  | 			return stateComment, 0, nil | 
					
						
							|  |  |  | 		case isSpace(r), isNewline(r): | 
					
						
							|  |  |  | 			return stateNil, 0, nil | 
					
						
							|  |  |  | 		default: | 
					
						
							|  |  |  | 			return stateName, r, nil | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	case stateName: | 
					
						
							|  |  |  | 		switch { | 
					
						
							|  |  |  | 		case isAlpha(r): | 
					
						
							|  |  |  | 			return stateName, r, nil | 
					
						
							|  |  |  | 		case isSpace(r): | 
					
						
							|  |  |  | 			return stateValue, 0, nil | 
					
						
							|  |  |  | 		default: | 
					
						
							| 
									
										
										
										
											2024-04-27 08:11:47 +08:00
										 |  |  | 			return stateNil, 0, errInvalidCommand | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	case stateValue: | 
					
						
							|  |  |  | 		switch { | 
					
						
							|  |  |  | 		case isNewline(r): | 
					
						
							|  |  |  | 			return stateNil, r, nil | 
					
						
							|  |  |  | 		case isSpace(r): | 
					
						
							|  |  |  | 			return stateNil, r, nil | 
					
						
							|  |  |  | 		default: | 
					
						
							|  |  |  | 			return stateValue, r, nil | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	case stateParameter: | 
					
						
							|  |  |  | 		switch { | 
					
						
							|  |  |  | 		case isAlpha(r), isNumber(r), r == '_': | 
					
						
							|  |  |  | 			return stateParameter, r, nil | 
					
						
							|  |  |  | 		case isSpace(r): | 
					
						
							|  |  |  | 			return stateValue, 0, nil | 
					
						
							|  |  |  | 		default: | 
					
						
							|  |  |  | 			return stateNil, 0, io.ErrUnexpectedEOF | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	case stateMessage: | 
					
						
							|  |  |  | 		switch { | 
					
						
							|  |  |  | 		case isAlpha(r): | 
					
						
							|  |  |  | 			return stateMessage, r, nil | 
					
						
							|  |  |  | 		case isSpace(r): | 
					
						
							|  |  |  | 			return stateValue, 0, nil | 
					
						
							|  |  |  | 		default: | 
					
						
							|  |  |  | 			return stateNil, 0, io.ErrUnexpectedEOF | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	case stateComment: | 
					
						
							|  |  |  | 		switch { | 
					
						
							|  |  |  | 		case isNewline(r): | 
					
						
							|  |  |  | 			return stateNil, 0, nil | 
					
						
							|  |  |  | 		default: | 
					
						
							|  |  |  | 			return stateComment, 0, nil | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return stateNil, 0, errors.New("") | 
					
						
							| 
									
										
										
										
											2023-07-28 00:55:48 +08:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2023-07-28 00:55:48 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-25 09:49:14 +08:00
										 |  |  | func quote(s string) string { | 
					
						
							| 
									
										
										
										
											2024-04-27 07:59:31 +08:00
										 |  |  | 	if strings.Contains(s, "\n") || strings.HasPrefix(s, " ") || strings.HasSuffix(s, " ") { | 
					
						
							| 
									
										
										
										
											2024-04-25 09:49:14 +08:00
										 |  |  | 		if strings.Contains(s, "\"") { | 
					
						
							|  |  |  | 			return `"""` + s + `"""` | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-27 07:59:31 +08:00
										 |  |  | 		return `"` + s + `"` | 
					
						
							| 
									
										
										
										
											2024-04-25 09:49:14 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return s | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | func unquote(s string) (string, bool) { | 
					
						
							|  |  |  | 	// TODO: single quotes
 | 
					
						
							|  |  |  | 	if len(s) >= 3 && s[:3] == `"""` { | 
					
						
							|  |  |  | 		if len(s) >= 6 && s[len(s)-3:] == `"""` { | 
					
						
							|  |  |  | 			return s[3 : len(s)-3], true | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		return "", false | 
					
						
							| 
									
										
										
										
											2023-07-28 00:55:48 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 	if len(s) >= 1 && s[0] == '"' { | 
					
						
							|  |  |  | 		if len(s) >= 2 && s[len(s)-1] == '"' { | 
					
						
							|  |  |  | 			return s[1 : len(s)-1], true | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		return "", false | 
					
						
							| 
									
										
										
										
											2023-07-28 00:55:48 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 	return s, true | 
					
						
							| 
									
										
										
										
											2023-07-28 00:55:48 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | func isAlpha(r rune) bool { | 
					
						
							|  |  |  | 	return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-07-18 05:21:27 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | func isNumber(r rune) bool { | 
					
						
							|  |  |  | 	return r >= '0' && r <= '9' | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-07-18 05:21:27 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | func isSpace(r rune) bool { | 
					
						
							|  |  |  | 	return r == ' ' || r == '\t' | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-07-26 02:50:23 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | func isNewline(r rune) bool { | 
					
						
							|  |  |  | 	return r == '\r' || r == '\n' | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-07-18 05:21:27 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-22 03:38:09 +08:00
										 |  |  | func isValidMessageRole(role string) bool { | 
					
						
							| 
									
										
										
										
											2024-04-23 06:37:14 +08:00
										 |  |  | 	return role == "system" || role == "user" || role == "assistant" | 
					
						
							| 
									
										
										
										
											2023-07-18 05:21:27 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2024-04-27 08:11:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | func isValidCommand(cmd string) bool { | 
					
						
							|  |  |  | 	switch strings.ToLower(cmd) { | 
					
						
							|  |  |  | 	case "from", "license", "template", "system", "adapter", "parameter", "message": | 
					
						
							|  |  |  | 		return true | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return false | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-11 08:14:08 +08:00
										 |  |  | func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) { | 
					
						
							| 
									
										
										
										
											2025-01-15 11:01:24 +08:00
										 |  |  | 	if filepath.IsAbs(path) || strings.HasPrefix(path, "\\") || strings.HasPrefix(path, "/") { | 
					
						
							|  |  |  | 		return filepath.Abs(path) | 
					
						
							|  |  |  | 	} else if strings.HasPrefix(path, "~") { | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 		var homeDir string | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if path == "~" || strings.HasPrefix(path, "~/") { | 
					
						
							|  |  |  | 			// Current user's home directory
 | 
					
						
							|  |  |  | 			currentUser, err := currentUserFunc() | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return "", fmt.Errorf("failed to get current user: %w", err) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			homeDir = currentUser.HomeDir | 
					
						
							|  |  |  | 			path = strings.TrimPrefix(path, "~") | 
					
						
							|  |  |  | 		} else { | 
					
						
							|  |  |  | 			// Specific user's home directory
 | 
					
						
							|  |  |  | 			parts := strings.SplitN(path[1:], "/", 2) | 
					
						
							|  |  |  | 			userInfo, err := lookupUserFunc(parts[0]) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return "", fmt.Errorf("failed to find user '%s': %w", parts[0], err) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			homeDir = userInfo.HomeDir | 
					
						
							|  |  |  | 			if len(parts) > 1 { | 
					
						
							|  |  |  | 				path = "/" + parts[1] | 
					
						
							|  |  |  | 			} else { | 
					
						
							|  |  |  | 				path = "" | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		path = filepath.Join(homeDir, path) | 
					
						
							| 
									
										
										
										
											2025-01-11 08:14:08 +08:00
										 |  |  | 	} else { | 
					
						
							|  |  |  | 		path = filepath.Join(relativeDir, path) | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return filepath.Abs(path) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-11 08:14:08 +08:00
										 |  |  | func expandPath(path, relativeDir string) (string, error) { | 
					
						
							|  |  |  | 	return expandPathImpl(path, relativeDir, user.Current, user.Lookup) | 
					
						
							| 
									
										
										
										
											2025-01-01 10:02:30 +08:00
										 |  |  | } |