mirror of https://github.com/ollama/ollama.git
Merge pull request #4547 from dhiltgen/load_progress

Wire up load progress

commit 95b1133d0c
@@ -334,6 +334,7 @@ struct server_metrics {
 struct llama_server_context
 {
     llama_model *model = nullptr;
+    float modelProgress = 0.0;
     llama_context *ctx = nullptr;

     clip_ctx *clp_ctx = nullptr;
@@ -2779,6 +2780,12 @@ inline void signal_handler(int signal) {
     shutdown_handler(signal);
 }

+static bool update_load_progress(float progress, void *data)
+{
+    ((llama_server_context*)data)->modelProgress = progress;
+    return true;
+}
+
 #if defined(_WIN32)
 char* wchar_to_char(const wchar_t* wstr) {
     if (wstr == nullptr) return nullptr;
@@ -2884,7 +2891,9 @@ int main(int argc, char **argv) {
                 break;
             }
             case SERVER_STATE_LOADING_MODEL:
-                res.set_content(R"({"status": "loading model"})", "application/json");
+                char buf[128];
+                snprintf(&buf[0], 128, R"({"status": "loading model", "progress": %0.2f})", llama.modelProgress);
+                res.set_content(buf, "application/json");
                 res.status = 503; // HTTP Service Unavailable
                 break;
             case SERVER_STATE_ERROR:
@@ -3079,6 +3088,9 @@ int main(int argc, char **argv) {
             });

     // load the model
+    params.progress_callback = update_load_progress;
+    params.progress_callback_user_data = (void*)&llama;
+
     if (!llama.load_model(params))
     {
         state.store(SERVER_STATE_ERROR);
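Note: with the handler change above, the runner's /health endpoint reports {"status": "loading model", "progress": ...} while a model is loading. Below is a minimal, hypothetical Go sketch (not part of this commit) of a client that polls that endpoint and prints the reported progress; the address is a placeholder, since the real runner port is chosen at startup.

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

// healthResp mirrors the JSON written by the loading-model branch above.
type healthResp struct {
	Status   string  `json:"status"`
	Progress float32 `json:"progress"`
}

func main() {
	// Placeholder address: the actual runner port is picked when the server is launched.
	url := "http://127.0.0.1:8080/health"
	for {
		resp, err := http.Get(url)
		if err != nil {
			time.Sleep(250 * time.Millisecond)
			continue
		}
		var hr healthResp
		decodeErr := json.NewDecoder(resp.Body).Decode(&hr)
		resp.Body.Close()
		if decodeErr != nil {
			time.Sleep(250 * time.Millisecond)
			continue
		}
		if hr.Status != "loading model" {
			fmt.Println("status:", hr.Status)
			return
		}
		fmt.Printf("load progress: %0.2f\n", hr.Progress)
		time.Sleep(250 * time.Millisecond)
	}
}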

@@ -0,0 +1,31 @@
diff --git a/common/common.cpp b/common/common.cpp
index ba1ecf0e..cead57cc 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1836,6 +1836,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     mparams.use_mmap        = params.use_mmap;
     mparams.use_mlock       = params.use_mlock;
     mparams.check_tensors   = params.check_tensors;
+    mparams.progress_callback = params.progress_callback;
+    mparams.progress_callback_user_data = params.progress_callback_user_data;
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
diff --git a/common/common.h b/common/common.h
index d80344f2..71e84834 100644
--- a/common/common.h
+++ b/common/common.h
@@ -174,6 +174,13 @@ struct gpt_params {
     // multimodal models (see examples/llava)
     std::string mmproj = "";        // path to multimodal projector
     std::vector<std::string> image; // path to image file(s)
+
+    // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
+    // If the provided progress_callback returns true, model loading continues.
+    // If it returns false, model loading is immediately aborted.
+    llama_progress_callback progress_callback = NULL;
+    // context pointer passed to the progress callback
+    void * progress_callback_user_data;
 };
 
 void gpt_params_handle_model_default(gpt_params & params);

@@ -55,6 +55,7 @@ type llmServer struct {
 	totalLayers    uint64
 	gpuCount       int
 	loadDuration   time.Duration // Record how long it took the model to load
+	loadProgress   float32

 	sem *semaphore.Weighted
 }
@@ -429,6 +430,7 @@ type ServerStatusResp struct {
 	SlotsIdle       int     `json:"slots_idle"`
 	SlotsProcessing int     `json:"slots_processing"`
 	Error           string  `json:"error"`
+	Progress        float32 `json:"progress"`
 }

 func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -476,6 +478,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	case "no slot available":
 		return ServerStatusNoSlotsAvailable, nil
 	case "loading model":
+		s.loadProgress = status.Progress
 		return ServerStatusLoadingModel, nil
 	default:
 		return ServerStatusError, fmt.Errorf("server error: %+v", status)
@@ -516,7 +519,8 @@ func (s *llmServer) Ping(ctx context.Context) error {

 func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
-	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
+	stallDuration := 60 * time.Second
+	stallTimer := time.Now().Add(stallDuration) // give up if we stall for

 	slog.Info("waiting for llama runner to start responding")
 	var lastStatus ServerStatus = -1
@@ -534,13 +538,13 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
 		default:
 		}
-		if time.Now().After(expiresAt) {
+		if time.Now().After(stallTimer) {
 			// timeout
 			msg := ""
 			if s.status != nil && s.status.LastErrMsg != "" {
 				msg = s.status.LastErrMsg
 			}
-			return fmt.Errorf("timed out waiting for llama runner to start: %s", msg)
+			return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
 		}
 		if s.cmd.ProcessState != nil {
 			msg := ""
@@ -551,6 +555,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 		}
 		ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
 		defer cancel()
+		priorProgress := s.loadProgress
 		status, _ := s.getServerStatus(ctx)
 		if lastStatus != status && status != ServerStatusReady {
 			// Only log on status changes
@@ -563,6 +568,11 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			return nil
 		default:
 			lastStatus = status
+			// Reset the timer as long as we're making forward progress on the load
+			if priorProgress != s.loadProgress {
+				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
+				stallTimer = time.Now().Add(stallDuration)
+			}
 			time.Sleep(time.Millisecond * 250)
 			continue
 		}
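Note: the reworked WaitUntilRunning above replaces the fixed 10-minute deadline with a stall timer that is pushed forward whenever the reported load progress advances, so a slow but still-progressing load no longer times out. A minimal, self-contained sketch of that pattern follows; the helper name and the simulated progress source are illustrative, not from this commit.

package main

import (
	"errors"
	"fmt"
	"time"
)

// waitWithStallTimer waits until progress() reaches 1.0. Instead of a fixed
// deadline, it only gives up when the value stops advancing for stallDuration.
func waitWithStallTimer(progress func() float32, stallDuration time.Duration) error {
	stallTimer := time.Now().Add(stallDuration)
	prior := float32(-1)
	for {
		p := progress()
		if p >= 1.0 {
			return nil
		}
		if p != prior {
			// Forward progress: push the stall timer out again.
			prior = p
			stallTimer = time.Now().Add(stallDuration)
		}
		if time.Now().After(stallTimer) {
			return errors.New("timed out: load stalled")
		}
		time.Sleep(250 * time.Millisecond)
	}
}

func main() {
	// Simulated loader that advances about 10% per second.
	start := time.Now()
	progress := func() float32 {
		return float32(time.Since(start).Seconds() / 10)
	}
	if err := waitWithStallTimer(progress, 3*time.Second); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("model loaded")
}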