mirror of https://github.com/ollama/ollama.git

Merge pull request #4547 from dhiltgen/load_progress

Wire up load progress

commit 95b1133d0c
@@ -334,6 +334,7 @@ struct server_metrics {
 struct llama_server_context
 {
     llama_model *model = nullptr;
+    float modelProgress = 0.0;
     llama_context *ctx = nullptr;
 
     clip_ctx *clp_ctx = nullptr;
@@ -2779,6 +2780,12 @@ inline void signal_handler(int signal) {
     shutdown_handler(signal);
 }
 
+static bool update_load_progress(float progress, void *data)
+{
+    ((llama_server_context*)data)->modelProgress = progress;
+    return true;
+}
+
 #if defined(_WIN32)
 char* wchar_to_char(const wchar_t* wstr) {
     if (wstr == nullptr) return nullptr;
@@ -2884,7 +2891,9 @@ int main(int argc, char **argv) {
                 break;
             }
             case SERVER_STATE_LOADING_MODEL:
-                res.set_content(R"({"status": "loading model"})", "application/json");
+                char buf[128];
+                snprintf(&buf[0], 128, R"({"status": "loading model", "progress": %0.2f})", llama.modelProgress);
+                res.set_content(buf, "application/json");
                 res.status = 503; // HTTP Service Unavailable
                 break;
             case SERVER_STATE_ERROR:
@@ -3079,6 +3088,9 @@ int main(int argc, char **argv) {
             });
 
     // load the model
+    params.progress_callback = update_load_progress;
+    params.progress_callback_user_data = (void*)&llama;
+
     if (!llama.load_model(params))
     {
         state.store(SERVER_STATE_ERROR);
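With the runner now embedding a progress field in its "loading model" health response, a client can watch the load advance by polling. A minimal sketch, assuming the runner's usual /health endpoint; the address and port below are illustrative, not taken from this commit, and the JSON shape matches the handler above:

// Minimal sketch: poll the runner's health endpoint during a model load
// and print the reported progress.
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

type healthResp struct {
	Status   string  `json:"status"`
	Progress float32 `json:"progress"`
}

func main() {
	for {
		resp, err := http.Get("http://127.0.0.1:8080/health") // hypothetical local runner address
		if err != nil {
			time.Sleep(250 * time.Millisecond)
			continue
		}
		var hr healthResp
		_ = json.NewDecoder(resp.Body).Decode(&hr)
		resp.Body.Close()
		fmt.Printf("status=%q progress=%0.2f\n", hr.Status, hr.Progress)
		if hr.Status != "loading model" { // served with HTTP 503 while loading
			return
		}
		time.Sleep(250 * time.Millisecond)
	}
}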
@@ -0,0 +1,31 @@
+diff --git a/common/common.cpp b/common/common.cpp
+index ba1ecf0e..cead57cc 100644
+--- a/common/common.cpp
++++ b/common/common.cpp
+@@ -1836,6 +1836,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
+     mparams.use_mmap        = params.use_mmap;
+     mparams.use_mlock       = params.use_mlock;
+     mparams.check_tensors   = params.check_tensors;
++    mparams.progress_callback = params.progress_callback;
++    mparams.progress_callback_user_data = params.progress_callback_user_data;
+     if (params.kv_overrides.empty()) {
+         mparams.kv_overrides = NULL;
+     } else {
+diff --git a/common/common.h b/common/common.h
+index d80344f2..71e84834 100644
+--- a/common/common.h
++++ b/common/common.h
+@@ -174,6 +174,13 @@ struct gpt_params {
+     // multimodal models (see examples/llava)
+     std::string mmproj = "";        // path to multimodal projector
+     std::vector<std::string> image; // path to image file(s)
++
++    // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
++    // If the provided progress_callback returns true, model loading continues.
++    // If it returns false, model loading is immediately aborted.
++    llama_progress_callback progress_callback = NULL;
++    // context pointer passed to the progress callback
++    void * progress_callback_user_data;
+ };
+ 
+ void gpt_params_handle_model_default(gpt_params & params);
@@ -55,6 +55,7 @@ type llmServer struct {
 	totalLayers    uint64
 	gpuCount       int
 	loadDuration   time.Duration // Record how long it took the model to load
+	loadProgress   float32
 
 	sem *semaphore.Weighted
 }
@@ -425,10 +426,11 @@ func (s ServerStatus) ToString() string {
 }
 
 type ServerStatusResp struct {
-	Status          string `json:"status"`
-	SlotsIdle       int    `json:"slots_idle"`
-	SlotsProcessing int    `json:"slots_processing"`
-	Error           string `json:"error"`
+	Status          string  `json:"status"`
+	SlotsIdle       int     `json:"slots_idle"`
+	SlotsProcessing int     `json:"slots_processing"`
+	Error           string  `json:"error"`
+	Progress        float32 `json:"progress"`
 }
 
 func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -476,6 +478,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	case "no slot available":
 		return ServerStatusNoSlotsAvailable, nil
 	case "loading model":
+		s.loadProgress = status.Progress
 		return ServerStatusLoadingModel, nil
 	default:
 		return ServerStatusError, fmt.Errorf("server error: %+v", status)
@@ -516,7 +519,8 @@ func (s *llmServer) Ping(ctx context.Context) error {
 
 func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
-	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
+	stallDuration := 60 * time.Second
+	stallTimer := time.Now().Add(stallDuration) // give up if we stall for
 
 	slog.Info("waiting for llama runner to start responding")
 	var lastStatus ServerStatus = -1
@@ -534,13 +538,13 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
 		default:
 		}
-		if time.Now().After(expiresAt) {
+		if time.Now().After(stallTimer) {
 			// timeout
 			msg := ""
 			if s.status != nil && s.status.LastErrMsg != "" {
 				msg = s.status.LastErrMsg
 			}
-			return fmt.Errorf("timed out waiting for llama runner to start: %s", msg)
+			return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
 		}
 		if s.cmd.ProcessState != nil {
 			msg := ""
@@ -551,6 +555,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 		}
 		ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
 		defer cancel()
+		priorProgress := s.loadProgress
 		status, _ := s.getServerStatus(ctx)
 		if lastStatus != status && status != ServerStatusReady {
 			// Only log on status changes
@@ -563,6 +568,11 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			return nil
 		default:
 			lastStatus = status
+			// Reset the timer as long as we're making forward progress on the load
+			if priorProgress != s.loadProgress {
+				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
+				stallTimer = time.Now().Add(stallDuration)
+			}
			time.Sleep(time.Millisecond * 250)
 			continue
 		}
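On the Go side the change swaps the fixed 10-minute load timeout for a rolling 60-second stall window that is pushed forward whenever the reported progress moves. A standalone sketch of that pattern under simulated progress; the function and variable names here are illustrative, not the repository's:

package main

import (
	"fmt"
	"time"
)

// waitWithStallTimer keeps a rolling deadline and pushes it forward whenever
// progress advances, so a slow but moving load never times out while a truly
// stuck load fails after stallDuration.
func waitWithStallTimer(progress func() float32, ready func() bool) error {
	stallDuration := 60 * time.Second
	stallTimer := time.Now().Add(stallDuration)
	prior := progress()
	for !ready() {
		if time.Now().After(stallTimer) {
			return fmt.Errorf("timed out waiting for load - progress %0.2f", prior)
		}
		if cur := progress(); cur != prior {
			prior = cur
			stallTimer = time.Now().Add(stallDuration) // forward progress resets the window
		}
		time.Sleep(250 * time.Millisecond)
	}
	return nil
}

func main() {
	// Simulated load: progress climbs a little on every poll and finishes at 1.0.
	p := float32(0)
	err := waitWithStallTimer(
		func() float32 { p += 0.05; return p },
		func() bool { return p >= 1.0 },
	)
	fmt.Println("result:", err)
}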