mirror of https://github.com/ollama/ollama.git
Merge a7236cd5fc
into bc71278670
This commit is contained in:
commit
43170b7b7d
256
cmd/cmd.go
256
cmd/cmd.go
|
@ -22,6 +22,7 @@ import (
|
|||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
@ -1030,73 +1031,215 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Check if parallel flag exists before accessing it
|
||||
// This is needed because PullHandler can be called from RunHandler
|
||||
// which doesn't have the parallel flag defined
|
||||
parallel := false
|
||||
if cmd.Flags().Lookup("parallel") != nil {
|
||||
parallel, err = cmd.Flags().GetBool("parallel")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
if parallel {
|
||||
// Parallel downloads
|
||||
return pullModelsParallel(cmd, client, args, insecure)
|
||||
} else {
|
||||
// Sequential downloads (existing behavior)
|
||||
return pullModelsSequential(cmd, client, args, insecure)
|
||||
}
|
||||
}
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
func pullModelsSequential(cmd *cobra.Command, client *api.Client, args []string, insecure bool) error {
|
||||
// Process each model sequentially
|
||||
for i, modelName := range args {
|
||||
if i > 0 {
|
||||
fmt.Printf("\n") // Add spacing between models
|
||||
}
|
||||
|
||||
fmt.Printf("pulling %s...\n", modelName)
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
// This is the initial status update for the
|
||||
// layer, which the server sends before
|
||||
// beginning the download, for clients to
|
||||
// compute total size and prepare for
|
||||
// downloads, if needed.
|
||||
//
|
||||
// Skipping this here to avoid showing a 0%
|
||||
// progress bar, which *should* clue the user
|
||||
// into the fact that many things are being
|
||||
// downloaded and that the current active
|
||||
// download is not that last. However, in rare
|
||||
// cases it seems to be triggering to some, and
|
||||
// it isn't worth explaining, so just ignore
|
||||
// and regress to the old UI that keeps giving
|
||||
// you the "But wait, there is more!" after
|
||||
// each "100% done" bar, which is "better."
|
||||
return nil
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
// This is the initial status update for the
|
||||
// layer, which the server sends before
|
||||
// beginning the download, for clients to
|
||||
// compute total size and prepare for
|
||||
// downloads, if needed.
|
||||
//
|
||||
// Skipping this here to avoid showing a 0%
|
||||
// progress bar, which *should* clue the user
|
||||
// into the fact that many things are being
|
||||
// downloaded and that the current active
|
||||
// download is not that last. However, in rare
|
||||
// cases it seems to be triggering to some, and
|
||||
// it isn't worth explaining, so just ignore
|
||||
// and regress to the old UI that keeps giving
|
||||
// you the "But wait, there is more!" after
|
||||
// each "100% done" bar, which is "better."
|
||||
return nil
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
request := api.PullRequest{Name: modelName, Insecure: insecure}
|
||||
err := client.Pull(cmd.Context(), &request, fn)
|
||||
p.Stop()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", modelName, err)
|
||||
}
|
||||
|
||||
fmt.Printf("success\n")
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: args[0], Insecure: insecure}
|
||||
return client.Pull(cmd.Context(), &request, fn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func pullModelsParallel(cmd *cobra.Command, client *api.Client, args []string, insecure bool) error {
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
errChan := make(chan error, len(args))
|
||||
|
||||
fmt.Printf("pulling %d models in parallel...\n", len(args))
|
||||
|
||||
// Create a single shared progress instance for all models
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
// Start a goroutine for each model
|
||||
for _, modelName := range args {
|
||||
wg.Add(1)
|
||||
go func(name string) {
|
||||
defer wg.Done()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
// Lock to prevent progress output collision
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
spinner = nil
|
||||
}
|
||||
|
||||
// Use model name + digest as unique key to avoid conflicts
|
||||
barKey := fmt.Sprintf("%s-%s", name, resp.Digest)
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
digestName, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
digestName = strings.TrimSpace(digestName)
|
||||
if isDigest {
|
||||
digestName = digestName[:min(12, len(digestName))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("[%s] pulling %s:", name, digestName), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(barKey, bar)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinnerKey := fmt.Sprintf("%s-status", name)
|
||||
spinner = progress.NewSpinner(fmt.Sprintf("[%s] %s", name, status))
|
||||
p.Add(spinnerKey, spinner)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
fmt.Printf("starting download: %s\n", name)
|
||||
mu.Unlock()
|
||||
|
||||
request := api.PullRequest{Name: name, Insecure: insecure}
|
||||
err := client.Pull(cmd.Context(), &request, fn)
|
||||
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to pull %s: %w", name, err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
fmt.Printf("completed: %s\n", name)
|
||||
mu.Unlock()
|
||||
}(modelName)
|
||||
}
|
||||
|
||||
// Wait for all downloads to complete
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// Check for errors
|
||||
var errors []error
|
||||
for err := range errChan {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
fmt.Printf("\nErrors occurred during parallel downloads:\n")
|
||||
for _, err := range errors {
|
||||
fmt.Printf(" - %v\n", err)
|
||||
}
|
||||
return fmt.Errorf("%d model(s) failed to download", len(errors))
|
||||
}
|
||||
|
||||
fmt.Printf("\nAll models downloaded successfully!\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
type generateContextKey string
|
||||
|
@ -1697,14 +1840,15 @@ func NewCLI() *cobra.Command {
|
|||
}
|
||||
|
||||
pullCmd := &cobra.Command{
|
||||
Use: "pull MODEL",
|
||||
Use: "pull MODEL [MODEL...]",
|
||||
Short: "Pull a model from a registry",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: PullHandler,
|
||||
}
|
||||
|
||||
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
pullCmd.Flags().Bool("parallel", false, "Download models in parallel")
|
||||
|
||||
pushCmd := &cobra.Command{
|
||||
Use: "push MODEL",
|
||||
|
@ -1903,4 +2047,4 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
|
|||
out += readline.ColorDefault
|
||||
}
|
||||
return out
|
||||
}
|
||||
}
|
|
@ -79,11 +79,12 @@ func AllowedOrigins() (origins []string) {
|
|||
return origins
|
||||
}
|
||||
|
||||
// Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable.
|
||||
// Default is $HOME/.ollama/models
|
||||
// Models returns the primary path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable.
|
||||
// For multiple paths, use ModelPaths() instead. Default is $HOME/.ollama/models
|
||||
func Models() string {
|
||||
if s := Var("OLLAMA_MODELS"); s != "" {
|
||||
return s
|
||||
paths := ModelPaths()
|
||||
if len(paths) > 0 {
|
||||
return paths[0]
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
|
@ -94,6 +95,46 @@ func Models() string {
|
|||
return filepath.Join(home, ".ollama", "models")
|
||||
}
|
||||
|
||||
// ModelPaths returns a list of paths to model directories. Paths can be configured via the OLLAMA_MODELS environment variable,
|
||||
// using colon-separated paths (similar to PATH). The first path is considered the primary path for storing new models.
|
||||
// Default is $HOME/.ollama/models
|
||||
func ModelPaths() []string {
|
||||
if s := Var("OLLAMA_MODELS"); s != "" {
|
||||
// Split by colon (Unix-style) or semicolon (Windows-style)
|
||||
separator := ":"
|
||||
if runtime.GOOS == "windows" {
|
||||
separator = ";"
|
||||
}
|
||||
paths := strings.Split(s, separator)
|
||||
|
||||
// Clean and validate each path
|
||||
var validPaths []string
|
||||
for _, path := range paths {
|
||||
path = strings.TrimSpace(path)
|
||||
if path != "" {
|
||||
// Convert to absolute path
|
||||
if absPath, err := filepath.Abs(path); err == nil {
|
||||
validPaths = append(validPaths, absPath)
|
||||
} else {
|
||||
validPaths = append(validPaths, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(validPaths) > 0 {
|
||||
return validPaths
|
||||
}
|
||||
}
|
||||
|
||||
// Default fallback
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return []string{filepath.Join(home, ".ollama", "models")}
|
||||
}
|
||||
|
||||
// KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable.
|
||||
// Negative values are treated as infinite. Zero is treated as no keep alive.
|
||||
// Default is 5 minutes.
|
||||
|
@ -279,7 +320,7 @@ func AsMap() map[string]EnvVar {
|
|||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path(s) to the models directory (colon-separated for multiple paths)"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
|
@ -327,3 +368,45 @@ func Values() map[string]string {
|
|||
func Var(key string) string {
|
||||
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
|
||||
}
|
||||
|
||||
// ValidateModelPaths checks if all configured model paths exist and are accessible
|
||||
func ValidateModelPaths() error {
|
||||
paths := ModelPaths()
|
||||
var errors []string
|
||||
|
||||
for _, path := range paths {
|
||||
if info, err := os.Stat(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Try to create the directory
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("cannot create model path %s: %v", path, err))
|
||||
}
|
||||
} else {
|
||||
errors = append(errors, fmt.Sprintf("cannot access model path %s: %v", path, err))
|
||||
}
|
||||
} else if !info.IsDir() {
|
||||
errors = append(errors, fmt.Sprintf("model path %s is not a directory", path))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("model path validation failed: %s", strings.Join(errors, "; "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrimaryModelPath returns the first (primary) model path where new models should be stored
|
||||
func PrimaryModelPath() string {
|
||||
paths := ModelPaths()
|
||||
if len(paths) > 0 {
|
||||
return paths[0]
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", "models")
|
||||
}
|
||||
|
|
36
go.mod
36
go.mod
|
@ -5,7 +5,6 @@ go 1.24.0
|
|||
require (
|
||||
github.com/containerd/console v1.0.3
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/spf13/cobra v1.7.0
|
||||
|
@ -30,46 +29,48 @@ require (
|
|||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/chewxy/hm v1.0.0 // indirect
|
||||
github.com/chewxy/math32 v1.11.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
|
@ -77,5 +78,4 @@ require (
|
|||
golang.org/x/term v0.30.0
|
||||
golang.org/x/text v0.23.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
|
|
@ -332,7 +332,7 @@ func GetModel(name string) (*Model, error) {
|
|||
}
|
||||
|
||||
if manifest.Config.Digest != "" {
|
||||
filename, err := GetBlobsPath(manifest.Config.Digest)
|
||||
filename, err := FindBlobPath(manifest.Config.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -349,7 +349,7 @@ func GetModel(name string) (*Model, error) {
|
|||
}
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
filename, err := GetBlobsPath(layer.Digest)
|
||||
filename, err := FindBlobPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -69,7 +69,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
|||
return Layer{}, errors.New("creating new layer from layer with empty digest")
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(digest)
|
||||
blob, err := FindBlobPath(digest)
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
|||
return nil, errors.New("opening layer with empty digest")
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(l.Digest)
|
||||
blob, err := FindBlobPath(l.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -121,7 +121,8 @@ func (l *Layer) Remove() error {
|
|||
}
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(l.Digest)
|
||||
// Try to find and remove the blob from any location
|
||||
blob, err := FindBlobPath(l.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -65,15 +65,14 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
|||
return nil, model.Unqualified(n)
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
// Try to find manifest in any of the model paths
|
||||
manifestPath, err := FindManifestPath(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := filepath.Join(manifests, n.Filepath())
|
||||
|
||||
var m Manifest
|
||||
f, err := os.Open(p)
|
||||
f, err := os.Open(manifestPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -89,7 +88,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
m.filepath = p
|
||||
m.filepath = manifestPath
|
||||
m.fi = fi
|
||||
m.digest = hex.EncodeToString(sha256sum.Sum(nil))
|
||||
|
||||
|
@ -124,53 +123,70 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
|||
}
|
||||
|
||||
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||
manifests, err := GetManifestPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(mxyng): use something less brittle
|
||||
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
|
||||
manifestPaths, err := GetManifestPaths()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ms := make(map[model.Name]*Manifest)
|
||||
for _, match := range matches {
|
||||
fi, err := os.Stat(match)
|
||||
|
||||
// Search through all manifest directories
|
||||
for _, manifestDir := range manifestPaths {
|
||||
// TODO(mxyng): use something less brittle
|
||||
matches, err := filepath.Glob(filepath.Join(manifestDir, "*", "*", "*", "*"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !continueOnError {
|
||||
return nil, err
|
||||
}
|
||||
slog.Warn("failed to glob manifests", "path", manifestDir, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !fi.IsDir() {
|
||||
rel, err := filepath.Rel(manifests, match)
|
||||
for _, match := range matches {
|
||||
fi, err := os.Stat(match)
|
||||
if err != nil {
|
||||
if !continueOnError {
|
||||
return nil, fmt.Errorf("%s %w", match, err)
|
||||
return nil, err
|
||||
}
|
||||
slog.Warn("bad filepath", "path", match, "error", err)
|
||||
slog.Warn("failed to stat manifest", "path", match, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
n := model.ParseNameFromFilepath(rel)
|
||||
if !n.IsValid() {
|
||||
if !continueOnError {
|
||||
return nil, fmt.Errorf("%s %w", rel, err)
|
||||
if !fi.IsDir() {
|
||||
rel, err := filepath.Rel(manifestDir, match)
|
||||
if err != nil {
|
||||
if !continueOnError {
|
||||
return nil, fmt.Errorf("%s %w", match, err)
|
||||
}
|
||||
slog.Warn("bad filepath", "path", match, "error", err)
|
||||
continue
|
||||
}
|
||||
slog.Warn("bad manifest name", "path", rel)
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
if !continueOnError {
|
||||
return nil, fmt.Errorf("%s %w", n, err)
|
||||
n := model.ParseNameFromFilepath(rel)
|
||||
if !n.IsValid() {
|
||||
if !continueOnError {
|
||||
return nil, fmt.Errorf("%s %w", rel, err)
|
||||
}
|
||||
slog.Warn("bad manifest name", "path", rel)
|
||||
continue
|
||||
}
|
||||
slog.Warn("bad manifest", "name", n, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
ms[n] = m
|
||||
// Skip if we already found this model in a higher priority path
|
||||
if _, exists := ms[n]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
if !continueOnError {
|
||||
return nil, fmt.Errorf("%s %w", n, err)
|
||||
}
|
||||
slog.Warn("bad manifest", "name", n, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
ms[n] = m
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -103,7 +103,9 @@ func (mp ModelPath) GetManifestPath() (string, error) {
|
|||
if !name.IsValid() {
|
||||
return "", fs.ErrNotExist
|
||||
}
|
||||
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
|
||||
|
||||
// Use the new multi-path search function
|
||||
return FindManifestPath(name)
|
||||
}
|
||||
|
||||
func (mp ModelPath) BaseURL() *url.URL {
|
||||
|
@ -122,6 +124,19 @@ func GetManifestPath() (string, error) {
|
|||
return path, nil
|
||||
}
|
||||
|
||||
// GetManifestPaths returns all manifest directories from all model paths
|
||||
func GetManifestPaths() ([]string, error) {
|
||||
modelPaths := envconfig.ModelPaths()
|
||||
manifestPaths := make([]string, 0, len(modelPaths))
|
||||
|
||||
for _, modelPath := range modelPaths {
|
||||
path := filepath.Join(modelPath, "manifests")
|
||||
manifestPaths = append(manifestPaths, path)
|
||||
}
|
||||
|
||||
return manifestPaths, nil
|
||||
}
|
||||
|
||||
func GetBlobsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
|
@ -144,3 +159,80 @@ func GetBlobsPath(digest string) (string, error) {
|
|||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// GetBlobsPaths returns all blob directories from all model paths
|
||||
func GetBlobsPaths() ([]string, error) {
|
||||
modelPaths := envconfig.ModelPaths()
|
||||
blobsPaths := make([]string, 0, len(modelPaths))
|
||||
|
||||
for _, modelPath := range modelPaths {
|
||||
path := filepath.Join(modelPath, "blobs")
|
||||
blobsPaths = append(blobsPaths, path)
|
||||
}
|
||||
|
||||
return blobsPaths, nil
|
||||
}
|
||||
|
||||
// FindBlobPath searches for a blob file across all model paths and returns the first found path
|
||||
func FindBlobPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
modelPaths := envconfig.ModelPaths()
|
||||
|
||||
// Search through all model paths
|
||||
for _, modelPath := range modelPaths {
|
||||
path := filepath.Join(modelPath, "blobs", digest)
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, return the path in the first (primary) model directory
|
||||
if len(modelPaths) > 0 {
|
||||
path := filepath.Join(modelPaths[0], "blobs", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no model paths configured")
|
||||
}
|
||||
|
||||
// FindManifestPath searches for a manifest file across all model paths and returns the first found path
|
||||
func FindManifestPath(name model.Name) (string, error) {
|
||||
if !name.IsValid() {
|
||||
return "", fs.ErrNotExist
|
||||
}
|
||||
|
||||
modelPaths := envconfig.ModelPaths()
|
||||
filePath := name.Filepath()
|
||||
|
||||
// Search through all model paths
|
||||
for _, modelPath := range modelPaths {
|
||||
path := filepath.Join(modelPath, "manifests", filePath)
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, return the path in the first (primary) model directory
|
||||
if len(modelPaths) > 0 {
|
||||
path := filepath.Join(modelPaths[0], "manifests", filePath)
|
||||
return path, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no model paths configured")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue