A-Akhil 2025-10-07 15:14:49 +00:00 committed by GitHub
commit 43170b7b7d
7 changed files with 455 additions and 119 deletions

View File

@ -22,6 +22,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@ -1030,73 +1031,215 @@ func PullHandler(cmd *cobra.Command, args []string) error {
return err
}
// Check if parallel flag exists before accessing it
// This is needed because PullHandler can be called from RunHandler
// which doesn't have the parallel flag defined
parallel := false
if cmd.Flags().Lookup("parallel") != nil {
parallel, err = cmd.Flags().GetBool("parallel")
if err != nil {
return err
}
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if parallel {
// Parallel downloads
return pullModelsParallel(cmd, client, args, insecure)
} else {
// Sequential downloads (existing behavior)
return pullModelsSequential(cmd, client, args, insecure)
}
}
func pullModelsSequential(cmd *cobra.Command, client *api.Client, args []string, insecure bool) error {
// Process each model sequentially
for i, modelName := range args {
if i > 0 {
fmt.Printf("\n") // Add spacing between models
}
fmt.Printf("pulling %s...\n", modelName)
p := progress.NewProgress(os.Stderr)
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
// This is the initial status update for the
// layer, which the server sends before
// beginning the download, for clients to
// compute total size and prepare for
// downloads, if needed.
//
// Skipping this here to avoid showing a 0%
// progress bar, which *should* clue the user
// into the fact that many things are being
// downloaded and that the current active
// download is not that last. However, in rare
// cases it seems to be triggering to some, and
// it isn't worth explaining, so just ignore
// and regress to the old UI that keeps giving
// you the "But wait, there is more!" after
// each "100% done" bar, which is "better."
return nil
}
if spinner != nil {
spinner.Stop()
}
bar, ok := bars[resp.Digest]
if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
}
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
request := api.PullRequest{Name: modelName, Insecure: insecure}
err := client.Pull(cmd.Context(), &request, fn)
p.Stop()
if err != nil {
return fmt.Errorf("failed to pull %s: %w", modelName, err)
}
fmt.Printf("success\n")
}
return nil
}
func pullModelsParallel(cmd *cobra.Command, client *api.Client, args []string, insecure bool) error {
var wg sync.WaitGroup
var mu sync.Mutex
errChan := make(chan error, len(args))
fmt.Printf("pulling %d models in parallel...\n", len(args))
// Create a single shared progress instance for all models
p := progress.NewProgress(os.Stderr)
defer p.Stop()
// Start a goroutine for each model
for _, modelName := range args {
wg.Add(1)
go func(name string) {
defer wg.Done()
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
// Lock to prevent progress output collision
mu.Lock()
defer mu.Unlock()
if resp.Digest != "" {
if resp.Completed == 0 {
return nil
}
if spinner != nil {
spinner.Stop()
spinner = nil
}
// Use model name + digest as unique key to avoid conflicts
barKey := fmt.Sprintf("%s-%s", name, resp.Digest)
bar, ok := bars[resp.Digest]
if !ok {
digestName, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
digestName = strings.TrimSpace(digestName)
if isDigest {
digestName = digestName[:min(12, len(digestName))]
}
bar = progress.NewBar(fmt.Sprintf("[%s] pulling %s:", name, digestName), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(barKey, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
}
status = resp.Status
spinnerKey := fmt.Sprintf("%s-status", name)
spinner = progress.NewSpinner(fmt.Sprintf("[%s] %s", name, status))
p.Add(spinnerKey, spinner)
}
return nil
}
mu.Lock()
fmt.Printf("starting download: %s\n", name)
mu.Unlock()
request := api.PullRequest{Name: name, Insecure: insecure}
err := client.Pull(cmd.Context(), &request, fn)
if err != nil {
errChan <- fmt.Errorf("failed to pull %s: %w", name, err)
return
}
mu.Lock()
fmt.Printf("completed: %s\n", name)
mu.Unlock()
}(modelName)
}
// Wait for all downloads to complete
wg.Wait()
close(errChan)
// Check for errors
var errors []error
for err := range errChan {
errors = append(errors, err)
}
if len(errors) > 0 {
fmt.Printf("\nErrors occurred during parallel downloads:\n")
for _, err := range errors {
fmt.Printf(" - %v\n", err)
}
return fmt.Errorf("%d model(s) failed to download", len(errors))
}
fmt.Printf("\nAll models downloaded successfully!\n")
return nil
}
type generateContextKey string
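The parallel path above starts one goroutine per model and funnels failures through an error channel buffered to len(args), so no send can block once the WaitGroup has drained, while a shared mutex keeps the interleaved status prints readable. A minimal, self-contained sketch of that fan-out/collect pattern (the fetch function and model names below are hypothetical stand-ins for client.Pull and real model names):

package main

import (
    "fmt"
    "sync"
)

// fetch is a hypothetical stand-in for client.Pull in the code above.
func fetch(name string) error {
    if name == "" {
        return fmt.Errorf("empty model name")
    }
    return nil
}

func pullAll(names []string) error {
    var wg sync.WaitGroup
    // Buffered to len(names) so goroutines never block on send.
    errChan := make(chan error, len(names))

    for _, name := range names {
        wg.Add(1)
        go func(n string) {
            defer wg.Done()
            if err := fetch(n); err != nil {
                errChan <- fmt.Errorf("failed to pull %s: %w", n, err)
            }
        }(name)
    }

    // Close only after every sender has finished.
    wg.Wait()
    close(errChan)

    var errs []error
    for err := range errChan {
        errs = append(errs, err)
    }
    if len(errs) > 0 {
        return fmt.Errorf("%d model(s) failed to download", len(errs))
    }
    return nil
}

func main() {
    // Hypothetical model names.
    fmt.Println(pullAll([]string{"model-a", "model-b"}))
}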
@ -1697,14 +1840,15 @@ func NewCLI() *cobra.Command {
}
pullCmd := &cobra.Command{
Use: "pull MODEL",
Use: "pull MODEL [MODEL...]",
Short: "Pull a model from a registry",
Args: cobra.ExactArgs(1),
Args: cobra.MinimumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: PullHandler,
}
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
pullCmd.Flags().Bool("parallel", false, "Download models in parallel")
pushCmd := &cobra.Command{
Use: "push MODEL",
@ -1903,4 +2047,4 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
out += readline.ColorDefault
}
return out
}
}

View File

@ -79,11 +79,12 @@ func AllowedOrigins() (origins []string) {
return origins
}
// Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable.
// Default is $HOME/.ollama/models
// Models returns the primary path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable.
// For multiple paths, use ModelPaths() instead. Default is $HOME/.ollama/models
func Models() string {
if s := Var("OLLAMA_MODELS"); s != "" {
return s
paths := ModelPaths()
if len(paths) > 0 {
return paths[0]
}
home, err := os.UserHomeDir()
@ -94,6 +95,46 @@ func Models() string {
return filepath.Join(home, ".ollama", "models")
}
// ModelPaths returns a list of paths to model directories. Paths can be configured via the OLLAMA_MODELS environment variable,
// using colon-separated paths (similar to PATH). The first path is considered the primary path for storing new models.
// Default is $HOME/.ollama/models
func ModelPaths() []string {
if s := Var("OLLAMA_MODELS"); s != "" {
// Split by colon (Unix-style) or semicolon (Windows-style)
separator := ":"
if runtime.GOOS == "windows" {
separator = ";"
}
paths := strings.Split(s, separator)
// Clean and validate each path
var validPaths []string
for _, path := range paths {
path = strings.TrimSpace(path)
if path != "" {
// Convert to absolute path
if absPath, err := filepath.Abs(path); err == nil {
validPaths = append(validPaths, absPath)
} else {
validPaths = append(validPaths, path)
}
}
}
if len(validPaths) > 0 {
return validPaths
}
}
// Default fallback
home, err := os.UserHomeDir()
if err != nil {
panic(err)
}
return []string{filepath.Join(home, ".ollama", "models")}
}
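To make the splitting rules above concrete, here is a small self-contained sketch that mirrors the separator and trimming logic; the directory names are made up, and the real ModelPaths additionally resolves each entry to an absolute path:

package main

import (
    "fmt"
    "runtime"
    "strings"
)

// splitModelPaths mirrors ModelPaths's splitting rules: colon-separated on
// Unix-like systems, semicolon-separated on Windows, blank entries dropped.
func splitModelPaths(value string) []string {
    separator := ":"
    if runtime.GOOS == "windows" {
        separator = ";"
    }
    var paths []string
    for _, p := range strings.Split(value, separator) {
        if p = strings.TrimSpace(p); p != "" {
            paths = append(paths, p)
        }
    }
    return paths
}

func main() {
    // Hypothetical directories, for illustration.
    fmt.Println(splitModelPaths("/mnt/fast/models:/mnt/archive/models"))
    // On Linux/macOS this prints: [/mnt/fast/models /mnt/archive/models]
    // The first entry is the primary path used for storing new models.
}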
// KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable.
// Negative values are treated as infinite. Zero is treated as no keep alive.
// Default is 5 minutes.
@ -279,7 +320,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path(s) to the models directory (colon-separated for multiple paths)"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
@ -327,3 +368,45 @@ func Values() map[string]string {
func Var(key string) string {
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
}
// ValidateModelPaths checks if all configured model paths exist and are accessible
func ValidateModelPaths() error {
paths := ModelPaths()
var errors []string
for _, path := range paths {
if info, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
// Try to create the directory
if err := os.MkdirAll(path, 0o755); err != nil {
errors = append(errors, fmt.Sprintf("cannot create model path %s: %v", path, err))
}
} else {
errors = append(errors, fmt.Sprintf("cannot access model path %s: %v", path, err))
}
} else if !info.IsDir() {
errors = append(errors, fmt.Sprintf("model path %s is not a directory", path))
}
}
if len(errors) > 0 {
return fmt.Errorf("model path validation failed: %s", strings.Join(errors, "; "))
}
return nil
}
// PrimaryModelPath returns the first (primary) model path where new models should be stored
func PrimaryModelPath() string {
paths := ModelPaths()
if len(paths) > 0 {
return paths[0]
}
home, err := os.UserHomeDir()
if err != nil {
panic(err)
}
return filepath.Join(home, ".ollama", "models")
}
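A sketch of how a server entry point might wire these helpers together at startup, assuming the usual github.com/ollama/ollama/envconfig import path; the diff itself does not show where (or whether) ValidateModelPaths is actually called, so this is illustrative only:

package main

import (
    "log"

    "github.com/ollama/ollama/envconfig"
)

func main() {
    // Fail fast if any configured model directory is missing and cannot be created.
    if err := envconfig.ValidateModelPaths(); err != nil {
        log.Fatalf("invalid OLLAMA_MODELS configuration: %v", err)
    }

    // New models are written under the first configured path.
    log.Printf("storing new models under %s", envconfig.PrimaryModelPath())
    log.Printf("searching for existing models in %v", envconfig.ModelPaths())
}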

go.mod
View File

@ -5,7 +5,6 @@ go 1.24.0
require (
github.com/containerd/console v1.0.3
github.com/gin-gonic/gin v1.10.0
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/uuid v1.6.0
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
@ -30,46 +29,48 @@ require (
require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect
)
require (
github.com/bytedance/sonic v1.11.6 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect
)
require (
github.com/gin-contrib/cors v1.7.2
golang.org/x/crypto v0.36.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.38.0 // indirect
@ -77,5 +78,4 @@ require (
golang.org/x/term v0.30.0
golang.org/x/text v0.23.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -332,7 +332,7 @@ func GetModel(name string) (*Model, error) {
}
if manifest.Config.Digest != "" {
filename, err := GetBlobsPath(manifest.Config.Digest)
filename, err := FindBlobPath(manifest.Config.Digest)
if err != nil {
return nil, err
}
@ -349,7 +349,7 @@ func GetModel(name string) (*Model, error) {
}
for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest)
filename, err := FindBlobPath(layer.Digest)
if err != nil {
return nil, err
}

View File

@ -69,7 +69,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
return Layer{}, errors.New("creating new layer from layer with empty digest")
}
blob, err := GetBlobsPath(digest)
blob, err := FindBlobPath(digest)
if err != nil {
return Layer{}, err
}
@ -93,7 +93,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
return nil, errors.New("opening layer with empty digest")
}
blob, err := GetBlobsPath(l.Digest)
blob, err := FindBlobPath(l.Digest)
if err != nil {
return nil, err
}
@ -121,7 +121,8 @@ func (l *Layer) Remove() error {
}
}
blob, err := GetBlobsPath(l.Digest)
// Try to find and remove the blob from any location
blob, err := FindBlobPath(l.Digest)
if err != nil {
return err
}

View File

@ -65,15 +65,14 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, model.Unqualified(n)
}
manifests, err := GetManifestPath()
// Try to find manifest in any of the model paths
manifestPath, err := FindManifestPath(n)
if err != nil {
return nil, err
}
p := filepath.Join(manifests, n.Filepath())
var m Manifest
f, err := os.Open(p)
f, err := os.Open(manifestPath)
if err != nil {
return nil, err
}
@ -89,7 +88,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err
}
m.filepath = p
m.filepath = manifestPath
m.fi = fi
m.digest = hex.EncodeToString(sha256sum.Sum(nil))
@ -124,53 +123,70 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
}
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
manifestPaths, err := GetManifestPaths()
if err != nil {
return nil, err
}
ms := make(map[model.Name]*Manifest)
// Search through all manifest directories
for _, manifestDir := range manifestPaths {
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(manifestDir, "*", "*", "*", "*"))
if err != nil {
if !continueOnError {
return nil, err
}
slog.Warn("failed to glob manifests", "path", manifestDir, "error", err)
continue
}
for _, match := range matches {
fi, err := os.Stat(match)
if err != nil {
if !continueOnError {
return nil, err
}
slog.Warn("failed to stat manifest", "path", match, "error", err)
continue
}
if !fi.IsDir() {
rel, err := filepath.Rel(manifestDir, match)
if err != nil {
if !continueOnError {
return nil, fmt.Errorf("%s %w", match, err)
}
slog.Warn("bad filepath", "path", match, "error", err)
continue
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
if !continueOnError {
return nil, fmt.Errorf("%s %w", rel, err)
}
slog.Warn("bad manifest name", "path", rel)
continue
}
// Skip if we already found this model in a higher priority path
if _, exists := ms[n]; exists {
continue
}
m, err := ParseNamedManifest(n)
if err != nil {
if !continueOnError {
return nil, fmt.Errorf("%s %w", n, err)
}
slog.Warn("bad manifest", "name", n, "error", err)
continue
}
ms[n] = m
}
}
}
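Concretely, if the same manifest (say library/llama3/latest, a hypothetical name) exists under both the first and second directories in OLLAMA_MODELS, only the copy under the first directory is returned: the exists check above skips later occurrences, so earlier (higher-priority) paths shadow later ones.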

View File

@ -103,7 +103,9 @@ func (mp ModelPath) GetManifestPath() (string, error) {
if !name.IsValid() {
return "", fs.ErrNotExist
}
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
// Use the new multi-path search function
return FindManifestPath(name)
}
func (mp ModelPath) BaseURL() *url.URL {
@ -122,6 +124,19 @@ func GetManifestPath() (string, error) {
return path, nil
}
// GetManifestPaths returns all manifest directories from all model paths
func GetManifestPaths() ([]string, error) {
modelPaths := envconfig.ModelPaths()
manifestPaths := make([]string, 0, len(modelPaths))
for _, modelPath := range modelPaths {
path := filepath.Join(modelPath, "manifests")
manifestPaths = append(manifestPaths, path)
}
return manifestPaths, nil
}
func GetBlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
@ -144,3 +159,80 @@ func GetBlobsPath(digest string) (string, error) {
return path, nil
}
// GetBlobsPaths returns all blob directories from all model paths
func GetBlobsPaths() ([]string, error) {
modelPaths := envconfig.ModelPaths()
blobsPaths := make([]string, 0, len(modelPaths))
for _, modelPath := range modelPaths {
path := filepath.Join(modelPath, "blobs")
blobsPaths = append(blobsPaths, path)
}
return blobsPaths, nil
}
// FindBlobPath searches for a blob file across all model paths and returns the first found path
func FindBlobPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
modelPaths := envconfig.ModelPaths()
// Search through all model paths
for _, modelPath := range modelPaths {
path := filepath.Join(modelPath, "blobs", digest)
if _, err := os.Stat(path); err == nil {
return path, nil
}
}
// If not found, return the path in the first (primary) model directory
if len(modelPaths) > 0 {
path := filepath.Join(modelPaths[0], "blobs", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
return "", fmt.Errorf("no model paths configured")
}
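To illustrate the lookup order: with a hypothetical OLLAMA_MODELS of /a/models:/b/models, the function probes /a/models/blobs/sha256-<digest> and then /b/models/blobs/sha256-<digest>, and if neither exists it returns (and prepares) the path under the primary directory /a/models. A stripped-down sketch of that search-then-fallback shape, omitting digest validation and directory creation:

package main

import (
    "fmt"
    "os"
    "path/filepath"
)

// findBlob mirrors FindBlobPath's search order: the first existing copy wins,
// otherwise fall back to the primary (first) model path. Illustration only;
// the real function also validates the digest and creates the blobs directory.
func findBlob(modelPaths []string, digest string) (string, error) {
    for _, mp := range modelPaths {
        p := filepath.Join(mp, "blobs", digest)
        if _, err := os.Stat(p); err == nil {
            return p, nil
        }
    }
    if len(modelPaths) > 0 {
        return filepath.Join(modelPaths[0], "blobs", digest), nil
    }
    return "", fmt.Errorf("no model paths configured")
}

func main() {
    // Hypothetical paths and digest placeholder.
    p, err := findBlob([]string{"/a/models", "/b/models"}, "sha256-0123abcd")
    fmt.Println(p, err)
}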
// FindManifestPath searches for a manifest file across all model paths and returns the first found path
func FindManifestPath(name model.Name) (string, error) {
if !name.IsValid() {
return "", fs.ErrNotExist
}
modelPaths := envconfig.ModelPaths()
filePath := name.Filepath()
// Search through all model paths
for _, modelPath := range modelPaths {
path := filepath.Join(modelPath, "manifests", filePath)
if _, err := os.Stat(path); err == nil {
return path, nil
}
}
// If not found, return the path in the first (primary) model directory
if len(modelPaths) > 0 {
path := filepath.Join(modelPaths[0], "manifests", filePath)
return path, nil
}
return "", fmt.Errorf("no model paths configured")
}
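Manifests resolve the same way: for a name whose Filepath() is, say, registry.ollama.ai/library/llama3/latest (a hypothetical example of the host/namespace/model/tag layout), each configured directory is probed at <dir>/manifests/registry.ollama.ai/library/llama3/latest in order, and the primary directory's path is returned when no existing copy is found.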