mirror of https://github.com/ollama/ollama.git
				
				
				
			Revert "chat api (#991)" while context variable is fixed
This reverts commit 7a0899d62d.
			
			
This commit is contained in:
		
							parent
							
								
									f1ef3f9947
								
							
						
					
					
						commit
						00d06619a1
					
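
The reverted change had routed chat history through a new `/api/chat` endpoint and a `messages` array; with this revert, conversational memory again travels as the `context` token array returned by `/api/generate` and echoed back on the next request. The sketch below illustrates that round trip; it is a minimal illustration assuming a server on the default port and an example model name (`llama2`), using only request and response fields documented in the `docs/api.md` hunk further down, with error handling abbreviated.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// generateResponse mirrors just the /api/generate response fields this
// sketch needs; the real api.GenerateResponse carries more.
type generateResponse struct {
	Response string `json:"response"`
	Context  []int  `json:"context,omitempty"`
	Done     bool   `json:"done"`
}

// generate sends one non-streaming /api/generate request and returns the
// response text plus the context encoding of the conversation so far.
func generate(prompt string, prevContext []int) (generateResponse, error) {
	body, _ := json.Marshal(map[string]any{
		"model":   "llama2", // example model name
		"prompt":  prompt,
		"context": prevContext,
		"stream":  false,
	})

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		return generateResponse{}, err
	}
	defer resp.Body.Close()

	var out generateResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return generateResponse{}, err
	}
	return out, nil
}

func main() {
	first, err := generate("why is the sky blue?", nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(first.Response)

	// Passing first.Context back in keeps a short conversational memory
	// without the removed /api/chat messages array.
	second, err := generate("how is that different than mie scattering?", first.Context)
	if err != nil {
		panic(err)
	}
	fmt.Println(second.Response)
}
```
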
				|  | @ -221,19 +221,6 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate | |||
| 	}) | ||||
| } | ||||
| 
 | ||||
| type ChatResponseFunc func(ChatResponse) error | ||||
| 
 | ||||
| func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error { | ||||
| 	return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error { | ||||
| 		var resp ChatResponse | ||||
| 		if err := json.Unmarshal(bts, &resp); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		return fn(resp) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| type PullProgressFunc func(ProgressResponse) error | ||||
| 
 | ||||
| func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { | ||||
|  |  | |||
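
The deleted `Chat` method mirrors the surviving `Generate` method: both POST a request through the client's `stream` helper, which decodes one JSON object per line of the response body and passes it to a callback. The `stream` helper itself is outside this hunk, so the following is only an illustrative stand-alone version of that newline-delimited decoding loop, not the repository's implementation; the `chunk` struct, URL, and buffer size are assumptions made for the sketch.

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// chunk holds only the fields this sketch reads from each streamed line.
type chunk struct {
	Response string `json:"response"`
	Done     bool   `json:"done"`
}

// streamJSONLines POSTs a JSON body and calls fn once per newline-delimited
// JSON object in the response, roughly the shape of the client's stream helper.
func streamJSONLines(url string, body []byte, fn func(chunk) error) error {
	resp, err := http.Post(url, "application/json", bytes.NewReader(body))
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 512*1024), 512*1024) // generous line buffer, akin to the runner's maxBufferSize
	for scanner.Scan() {
		var c chunk
		if err := json.Unmarshal(scanner.Bytes(), &c); err != nil {
			return err
		}
		if err := fn(c); err != nil {
			return err
		}
	}
	return scanner.Err()
}

func main() {
	body, _ := json.Marshal(map[string]any{"model": "llama2", "prompt": "hello"})
	err := streamJSONLines("http://localhost:11434/api/generate", body, func(c chunk) error {
		fmt.Print(c.Response) // print tokens as they arrive
		return nil
	})
	if err != nil {
		fmt.Println("stream error:", err)
	}
}
```
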
							
								
								
									
api/types.go | 74
							|  | @ -36,7 +36,7 @@ type GenerateRequest struct { | |||
| 	Prompt   string `json:"prompt"` | ||||
| 	System   string `json:"system"` | ||||
| 	Template string `json:"template"` | ||||
| 	Context  []int  `json:"context,omitempty"` // DEPRECATED: context is deprecated, use the /chat endpoint instead for chat history | ||||
| 	Context  []int  `json:"context,omitempty"` | ||||
| 	Stream   *bool  `json:"stream,omitempty"` | ||||
| 	Raw      bool   `json:"raw,omitempty"` | ||||
| 	Format   string `json:"format"` | ||||
|  | @ -44,41 +44,6 @@ type GenerateRequest struct { | |||
| 	Options map[string]interface{} `json:"options"` | ||||
| } | ||||
| 
 | ||||
| type ChatRequest struct { | ||||
| 	Model    string    `json:"model"` | ||||
| 	Messages []Message `json:"messages"` | ||||
| 	Template string    `json:"template"` | ||||
| 	Stream   *bool     `json:"stream,omitempty"` | ||||
| 	Format   string    `json:"format"` | ||||
| 
 | ||||
| 	Options map[string]interface{} `json:"options"` | ||||
| } | ||||
| 
 | ||||
| type Message struct { | ||||
| 	Role    string `json:"role"` // one of ["system", "user", "assistant"] | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
| 
 | ||||
| type ChatResponse struct { | ||||
| 	Model     string    `json:"model"` | ||||
| 	CreatedAt time.Time `json:"created_at"` | ||||
| 	Message   *Message  `json:"message,omitempty"` | ||||
| 
 | ||||
| 	Done    bool  `json:"done"` | ||||
| 	Context []int `json:"context,omitempty"` | ||||
| 
 | ||||
| 	EvalMetrics | ||||
| } | ||||
| 
 | ||||
| type EvalMetrics struct { | ||||
| 	TotalDuration      time.Duration `json:"total_duration,omitempty"` | ||||
| 	LoadDuration       time.Duration `json:"load_duration,omitempty"` | ||||
| 	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"` | ||||
| 	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` | ||||
| 	EvalCount          int           `json:"eval_count,omitempty"` | ||||
| 	EvalDuration       time.Duration `json:"eval_duration,omitempty"` | ||||
| } | ||||
| 
 | ||||
| // Options specfied in GenerateRequest, if you add a new option here add it to the API docs also | ||||
| type Options struct { | ||||
| 	Runner | ||||
|  | @ -208,34 +173,39 @@ type GenerateResponse struct { | |||
| 	Done    bool  `json:"done"` | ||||
| 	Context []int `json:"context,omitempty"` | ||||
| 
 | ||||
| 	EvalMetrics | ||||
| 	TotalDuration      time.Duration `json:"total_duration,omitempty"` | ||||
| 	LoadDuration       time.Duration `json:"load_duration,omitempty"` | ||||
| 	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"` | ||||
| 	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` | ||||
| 	EvalCount          int           `json:"eval_count,omitempty"` | ||||
| 	EvalDuration       time.Duration `json:"eval_duration,omitempty"` | ||||
| } | ||||
| 
 | ||||
| func (m *EvalMetrics) Summary() { | ||||
| 	if m.TotalDuration > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "total duration:       %v\n", m.TotalDuration) | ||||
| func (r *GenerateResponse) Summary() { | ||||
| 	if r.TotalDuration > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "total duration:       %v\n", r.TotalDuration) | ||||
| 	} | ||||
| 
 | ||||
| 	if m.LoadDuration > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "load duration:        %v\n", m.LoadDuration) | ||||
| 	if r.LoadDuration > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "load duration:        %v\n", r.LoadDuration) | ||||
| 	} | ||||
| 
 | ||||
| 	if m.PromptEvalCount > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "prompt eval count:    %d token(s)\n", m.PromptEvalCount) | ||||
| 	if r.PromptEvalCount > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "prompt eval count:    %d token(s)\n", r.PromptEvalCount) | ||||
| 	} | ||||
| 
 | ||||
| 	if m.PromptEvalDuration > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration) | ||||
| 		fmt.Fprintf(os.Stderr, "prompt eval rate:     %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds()) | ||||
| 	if r.PromptEvalDuration > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration) | ||||
| 		fmt.Fprintf(os.Stderr, "prompt eval rate:     %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds()) | ||||
| 	} | ||||
| 
 | ||||
| 	if m.EvalCount > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "eval count:           %d token(s)\n", m.EvalCount) | ||||
| 	if r.EvalCount > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "eval count:           %d token(s)\n", r.EvalCount) | ||||
| 	} | ||||
| 
 | ||||
| 	if m.EvalDuration > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "eval duration:        %s\n", m.EvalDuration) | ||||
| 		fmt.Fprintf(os.Stderr, "eval rate:            %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds()) | ||||
| 	if r.EvalDuration > 0 { | ||||
| 		fmt.Fprintf(os.Stderr, "eval duration:        %s\n", r.EvalDuration) | ||||
| 		fmt.Fprintf(os.Stderr, "eval rate:            %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds()) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  |  | |||
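
This hunk folds the short-lived `EvalMetrics` struct back into `GenerateResponse` and moves `Summary()` with it; the rates it prints follow the formula stated in `docs/api.md` (tokens per second = eval count divided by eval duration, with durations in nanoseconds). A self-contained sketch of that arithmetic, using a local struct rather than the package type and the sample numbers from the documentation's example response:

```go
package main

import (
	"fmt"
	"time"
)

// metrics mirrors the duration/count fields carried on the final
// GenerateResponse; the values below come from the docs/api.md example.
type metrics struct {
	PromptEvalCount    int
	PromptEvalDuration time.Duration
	EvalCount          int
	EvalDuration       time.Duration
}

func main() {
	m := metrics{
		PromptEvalCount:    46,
		PromptEvalDuration: 1160282000 * time.Nanosecond, // durations arrive in nanoseconds
		EvalCount:          113,
		EvalDuration:       1325948000 * time.Nanosecond,
	}

	if m.PromptEvalDuration > 0 {
		fmt.Printf("prompt eval rate: %.2f tokens/s\n",
			float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds())
	}
	if m.EvalDuration > 0 {
		fmt.Printf("eval rate:        %.2f tokens/s\n",
			float64(m.EvalCount)/m.EvalDuration.Seconds())
	}
}
```
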
							
								
								
									
docs/api.md | 152
							|  | @ -24,7 +24,7 @@ All durations are returned in nanoseconds. | |||
| 
 | ||||
| ### Streaming responses | ||||
| 
 | ||||
| Certain endpoints stream responses as JSON objects. | ||||
| Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character. | ||||
| 
 | ||||
| ## Generate a completion | ||||
| 
 | ||||
|  | @ -32,12 +32,10 @@ Certain endpoints stream responses as JSON objects. | |||
| POST /api/generate | ||||
| ``` | ||||
| 
 | ||||
| Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request. | ||||
| Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request. | ||||
| 
 | ||||
| ### Parameters | ||||
| 
 | ||||
| `model` is required. | ||||
| 
 | ||||
| - `model`: (required) the [model name](#model-names) | ||||
| - `prompt`: the prompt to generate a response for | ||||
| 
 | ||||
|  | @ -45,10 +43,11 @@ Advanced parameters (optional): | |||
| 
 | ||||
| - `format`: the format to return a response in. Currently the only accepted value is `json` | ||||
| - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` | ||||
| - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`) | ||||
| - `system`: system prompt to (overrides what is defined in the `Modelfile`) | ||||
| - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`) | ||||
| - `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory | ||||
| - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects | ||||
| - `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API. | ||||
| - `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself. | ||||
| 
 | ||||
| ### JSON mode | ||||
| 
 | ||||
|  | @ -58,7 +57,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur | |||
| 
 | ||||
| ### Examples | ||||
| 
 | ||||
| #### Request (Prompt) | ||||
| #### Request | ||||
| 
 | ||||
| ```shell | ||||
| curl http://localhost:11434/api/generate -d '{ | ||||
|  | @ -90,7 +89,7 @@ The final response in the stream also includes additional data about the generat | |||
| - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt | ||||
| - `eval_count`: number of tokens the response | ||||
| - `eval_duration`: time in nanoseconds spent generating the response | ||||
| - `context`: deprecated, an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory | ||||
| - `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory | ||||
| - `response`: empty if the response was streamed, if not streamed, this will contain the full response | ||||
| 
 | ||||
| To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`. | ||||
|  | @ -115,8 +114,6 @@ To calculate how fast the response is generated in tokens per second (token/s), | |||
| 
 | ||||
| #### Request (No streaming) | ||||
| 
 | ||||
| A response can be recieved in one reply when streaming is off. | ||||
| 
 | ||||
| ```shell | ||||
| curl http://localhost:11434/api/generate -d '{ | ||||
|   "model": "llama2", | ||||
|  | @ -147,9 +144,9 @@ If `stream` is set to `false`, the response will be a single JSON object: | |||
| } | ||||
| ``` | ||||
| 
 | ||||
| #### Request (Raw Mode) | ||||
| #### Request (Raw mode) | ||||
| 
 | ||||
| In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting. | ||||
| In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context. | ||||
| 
 | ||||
| ```shell | ||||
| curl http://localhost:11434/api/generate -d '{ | ||||
|  | @ -167,7 +164,6 @@ curl http://localhost:11434/api/generate -d '{ | |||
|   "model": "mistral", | ||||
|   "created_at": "2023-11-03T15:36:02.583064Z", | ||||
|   "response": " The sky appears blue because of a phenomenon called Rayleigh scattering.", | ||||
|   "context": [1, 2, 3], | ||||
|   "done": true, | ||||
|   "total_duration": 14648695333, | ||||
|   "load_duration": 3302671417, | ||||
|  | @ -279,6 +275,7 @@ curl http://localhost:11434/api/generate -d '{ | |||
|   "model": "llama2", | ||||
|   "created_at": "2023-08-04T19:22:45.499127Z", | ||||
|   "response": "The sky is blue because it is the color of the sky.", | ||||
|   "context": [1, 2, 3], | ||||
|   "done": true, | ||||
|   "total_duration": 5589157167, | ||||
|   "load_duration": 3013701500, | ||||
|  | @ -291,135 +288,6 @@ curl http://localhost:11434/api/generate -d '{ | |||
| } | ||||
| ``` | ||||
| 
 | ||||
| ## Send Chat Messages | ||||
| ```shell | ||||
| POST /api/chat | ||||
| ``` | ||||
| 
 | ||||
| Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request. | ||||
| 
 | ||||
| ### Parameters | ||||
| 
 | ||||
| `model` is required. | ||||
| 
 | ||||
| - `model`: (required) the [model name](#model-names) | ||||
| - `messages`: the messages of the chat, this can be used to keep a chat memory | ||||
| 
 | ||||
| Advanced parameters (optional): | ||||
| 
 | ||||
| - `format`: the format to return a response in. Currently the only accepted value is `json` | ||||
| - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` | ||||
| - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`) | ||||
| - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects | ||||
| 
 | ||||
| ### Examples | ||||
| 
 | ||||
| #### Request | ||||
| Send a chat message with a streaming response. | ||||
| 
 | ||||
| ```shell | ||||
| curl http://localhost:11434/api/generate -d '{ | ||||
|   "model": "llama2", | ||||
|   "messages": [ | ||||
|     { | ||||
|       "role": "user", | ||||
|       "content": "why is the sky blue?" | ||||
|     } | ||||
|   ] | ||||
| }' | ||||
| ``` | ||||
| 
 | ||||
| #### Response | ||||
| 
 | ||||
| A stream of JSON objects is returned: | ||||
| 
 | ||||
| ```json | ||||
| { | ||||
|   "model": "llama2", | ||||
|   "created_at": "2023-08-04T08:52:19.385406455-07:00", | ||||
|   "message": { | ||||
|     "role": "assisant", | ||||
|     "content": "The" | ||||
|   }, | ||||
|   "done": false | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| Final response: | ||||
| 
 | ||||
| ```json | ||||
| { | ||||
|   "model": "llama2", | ||||
|   "created_at": "2023-08-04T19:22:45.499127Z", | ||||
|   "done": true, | ||||
|   "total_duration": 5589157167, | ||||
|   "load_duration": 3013701500, | ||||
|   "sample_count": 114, | ||||
|   "sample_duration": 81442000, | ||||
|   "prompt_eval_count": 46, | ||||
|   "prompt_eval_duration": 1160282000, | ||||
|   "eval_count": 113, | ||||
|   "eval_duration": 1325948000 | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| #### Request (With History) | ||||
| Send a chat message with a conversation history. | ||||
| 
 | ||||
| ```shell | ||||
| curl http://localhost:11434/api/generate -d '{ | ||||
|   "model": "llama2", | ||||
|   "messages": [ | ||||
|     { | ||||
|       "role": "user", | ||||
|       "content": "why is the sky blue?" | ||||
|     }, | ||||
|     { | ||||
|       "role": "assistant", | ||||
|       "content": "due to rayleigh scattering." | ||||
|     }, | ||||
|     { | ||||
|       "role": "user", | ||||
|       "content": "how is that different than mie scattering?" | ||||
|     } | ||||
|   ] | ||||
| }' | ||||
| ``` | ||||
| 
 | ||||
| #### Response | ||||
| 
 | ||||
| A stream of JSON objects is returned: | ||||
| 
 | ||||
| ```json | ||||
| { | ||||
|   "model": "llama2", | ||||
|   "created_at": "2023-08-04T08:52:19.385406455-07:00", | ||||
|   "message": { | ||||
|     "role": "assisant", | ||||
|     "content": "The" | ||||
|   }, | ||||
|   "done": false | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| Final response: | ||||
| 
 | ||||
| ```json | ||||
| { | ||||
|   "model": "llama2", | ||||
|   "created_at": "2023-08-04T19:22:45.499127Z", | ||||
|   "done": true, | ||||
|   "total_duration": 5589157167, | ||||
|   "load_duration": 3013701500, | ||||
|   "sample_count": 114, | ||||
|   "sample_duration": 81442000, | ||||
|   "prompt_eval_count": 46, | ||||
|   "prompt_eval_duration": 1160282000, | ||||
|   "eval_count": 113, | ||||
|   "eval_duration": 1325948000 | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| ## Create a Model | ||||
| 
 | ||||
| ```shell | ||||
|  |  | |||
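
The documentation above also keeps JSON mode: setting `format` to `json` constrains generation (the `llm/llama.go` hunk below attaches a JSON grammar), so the final `response` string should itself parse as JSON. A hedged sketch of consuming that from Go; the prompt, the model name, and the assumption that the model actually emits valid JSON are illustrative, not guaranteed by this diff.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Ask for a JSON-structured answer; the prompt itself also requests JSON.
	body, _ := json.Marshal(map[string]any{
		"model":  "llama2",
		"prompt": "List three primary colors. Respond using JSON.",
		"format": "json",
		"stream": false,
	})

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out struct {
		Response string `json:"response"`
		Done     bool   `json:"done"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}

	// With format=json the response string is expected to be valid JSON.
	var structured map[string]any
	if err := json.Unmarshal([]byte(out.Response), &structured); err != nil {
		panic(err)
	}
	fmt.Println(structured)
}
```
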
							
								
								
									
llm/llama.go | 52
							|  | @ -531,31 +531,21 @@ type prediction struct { | |||
| 
 | ||||
| const maxBufferSize = 512 * format.KiloByte | ||||
| 
 | ||||
| type PredictRequest struct { | ||||
| 	Model            string | ||||
| 	Prompt           string | ||||
| 	Format           string | ||||
| 	CheckpointStart  time.Time | ||||
| 	CheckpointLoaded time.Time | ||||
| func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error { | ||||
| 	prevConvo, err := llm.Decode(ctx, prevContext) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| type PredictResponse struct { | ||||
| 	Model              string | ||||
| 	CreatedAt          time.Time | ||||
| 	TotalDuration      time.Duration | ||||
| 	LoadDuration       time.Duration | ||||
| 	Content            string | ||||
| 	Done               bool | ||||
| 	PromptEvalCount    int | ||||
| 	PromptEvalDuration time.Duration | ||||
| 	EvalCount          int | ||||
| 	EvalDuration       time.Duration | ||||
| 	Context            []int | ||||
| } | ||||
| 	// Remove leading spaces from prevConvo if present | ||||
| 	prevConvo = strings.TrimPrefix(prevConvo, " ") | ||||
| 
 | ||||
| 	var nextContext strings.Builder | ||||
| 	nextContext.WriteString(prevConvo) | ||||
| 	nextContext.WriteString(prompt) | ||||
| 
 | ||||
| func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(PredictResponse)) error { | ||||
| 	request := map[string]any{ | ||||
| 		"prompt":            predict.Prompt, | ||||
| 		"prompt":            nextContext.String(), | ||||
| 		"stream":            true, | ||||
| 		"n_predict":         llm.NumPredict, | ||||
| 		"n_keep":            llm.NumKeep, | ||||
|  | @ -577,7 +567,7 @@ func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(P | |||
| 		"stop":              llm.Stop, | ||||
| 	} | ||||
| 
 | ||||
| 	if predict.Format == "json" { | ||||
| 	if format == "json" { | ||||
| 		request["grammar"] = jsonGrammar | ||||
| 	} | ||||
| 
 | ||||
|  | @ -634,25 +624,25 @@ func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(P | |||
| 				} | ||||
| 
 | ||||
| 				if p.Content != "" { | ||||
| 					fn(PredictResponse{ | ||||
| 						Model:     predict.Model, | ||||
| 						CreatedAt: time.Now().UTC(), | ||||
| 						Content:   p.Content, | ||||
| 					}) | ||||
| 					fn(api.GenerateResponse{Response: p.Content}) | ||||
| 					nextContext.WriteString(p.Content) | ||||
| 				} | ||||
| 
 | ||||
| 				if p.Stop { | ||||
| 					fn(PredictResponse{ | ||||
| 						Model:         predict.Model, | ||||
| 						CreatedAt:     time.Now().UTC(), | ||||
| 						TotalDuration: time.Since(predict.CheckpointStart), | ||||
| 					embd, err := llm.Encode(ctx, nextContext.String()) | ||||
| 					if err != nil { | ||||
| 						return fmt.Errorf("encoding context: %v", err) | ||||
| 					} | ||||
| 
 | ||||
| 					fn(api.GenerateResponse{ | ||||
| 						Done:               true, | ||||
| 						Context:            embd, | ||||
| 						PromptEvalCount:    p.Timings.PromptN, | ||||
| 						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS), | ||||
| 						EvalCount:          p.Timings.PredictedN, | ||||
| 						EvalDuration:       parseDurationMs(p.Timings.PredictedMS), | ||||
| 					}) | ||||
| 
 | ||||
| 					return nil | ||||
| 				} | ||||
| 			} | ||||
|  |  | |||
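
The restored `Predict` owns context handling end to end: it decodes `prevContext` to text, trims a leading space, appends the prompt and each generated piece to a `strings.Builder`, and re-encodes the result as the `Context` of the final response. The sketch below isolates that rebuild logic; the `codec` interface and `fakeCodec` tokenizer are inventions for illustration only, standing in for the runner's `Encode`/`Decode` calls.

```go
package main

import (
	"fmt"
	"strings"
)

// codec is a stand-in for the Encode/Decode pair the llama runner exposes.
type codec interface {
	Encode(s string) []int
	Decode(tokens []int) string
}

// rebuildContext mirrors the flow in Predict: previous context + prompt +
// generated pieces become the next context encoding.
func rebuildContext(c codec, prevContext []int, prompt string, pieces []string) []int {
	prevConvo := c.Decode(prevContext)
	prevConvo = strings.TrimPrefix(prevConvo, " ") // drop a leading space if present

	var next strings.Builder
	next.WriteString(prevConvo)
	next.WriteString(prompt)
	for _, p := range pieces {
		next.WriteString(p)
	}
	return c.Encode(next.String())
}

// fakeCodec "tokenizes" by bytes, purely so the sketch runs.
type fakeCodec struct{}

func (fakeCodec) Encode(s string) []int {
	out := make([]int, len(s))
	for i := 0; i < len(s); i++ {
		out[i] = int(s[i])
	}
	return out
}

func (fakeCodec) Decode(tokens []int) string {
	b := make([]byte, len(tokens))
	for i, t := range tokens {
		b[i] = byte(t)
	}
	return string(b)
}

func main() {
	var c fakeCodec
	first := rebuildContext(c, nil, "why is the sky blue? ", []string{"Rayleigh", " scattering."})
	second := rebuildContext(c, first, " and sunsets?", []string{" Longer", " path."})
	fmt.Println(len(first), len(second)) // the second context contains the whole conversation
}
```
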
|  | @ -14,7 +14,7 @@ import ( | |||
| ) | ||||
| 
 | ||||
| type LLM interface { | ||||
| 	Predict(context.Context, PredictRequest, func(PredictResponse)) error | ||||
| 	Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error | ||||
| 	Embedding(context.Context, string) ([]float64, error) | ||||
| 	Encode(context.Context, string) ([]int, error) | ||||
| 	Decode(context.Context, []int) (string, error) | ||||
|  |  | |||
|  | @ -47,82 +47,37 @@ type Model struct { | |||
| 	Options       map[string]interface{} | ||||
| } | ||||
| 
 | ||||
| type PromptVars struct { | ||||
| 	System   string | ||||
| 	Prompt   string | ||||
| 	Response string | ||||
| func (m *Model) Prompt(request api.GenerateRequest) (string, error) { | ||||
| 	t := m.Template | ||||
| 	if request.Template != "" { | ||||
| 		t = request.Template | ||||
| 	} | ||||
| 
 | ||||
| func (m *Model) Prompt(p PromptVars) (string, error) { | ||||
| 	var prompt strings.Builder | ||||
| 	tmpl, err := template.New("").Parse(m.Template) | ||||
| 	tmpl, err := template.New("").Parse(t) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	if p.System == "" { | ||||
| 		// use the default system prompt for this model if one is not specified | ||||
| 		p.System = m.System | ||||
| 	var vars struct { | ||||
| 		First  bool | ||||
| 		System string | ||||
| 		Prompt string | ||||
| 	} | ||||
| 
 | ||||
| 	vars.First = len(request.Context) == 0 | ||||
| 	vars.System = m.System | ||||
| 	vars.Prompt = request.Prompt | ||||
| 
 | ||||
| 	if request.System != "" { | ||||
| 		vars.System = request.System | ||||
| 	} | ||||
| 
 | ||||
| 	var sb strings.Builder | ||||
| 	if err := tmpl.Execute(&sb, p); err != nil { | ||||
| 	if err := tmpl.Execute(&sb, vars); err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	prompt.WriteString(sb.String()) | ||||
| 	prompt.WriteString(p.Response) | ||||
| 	return prompt.String(), nil | ||||
| } | ||||
| 
 | ||||
| func (m *Model) ChatPrompt(msgs []api.Message) (string, error) { | ||||
| 	// build the prompt from the list of messages | ||||
| 	var prompt strings.Builder | ||||
| 	currentVars := PromptVars{} | ||||
| 
 | ||||
| 	writePrompt := func() error { | ||||
| 		p, err := m.Prompt(currentVars) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		prompt.WriteString(p) | ||||
| 		currentVars = PromptVars{} | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	for _, msg := range msgs { | ||||
| 		switch msg.Role { | ||||
| 		case "system": | ||||
| 			if currentVars.Prompt != "" || currentVars.System != "" { | ||||
| 				if err := writePrompt(); err != nil { | ||||
| 					return "", err | ||||
| 				} | ||||
| 			} | ||||
| 			currentVars.System = msg.Content | ||||
| 		case "user": | ||||
| 			if currentVars.Prompt != "" || currentVars.System != "" { | ||||
| 				if err := writePrompt(); err != nil { | ||||
| 					return "", err | ||||
| 				} | ||||
| 			} | ||||
| 			currentVars.Prompt = msg.Content | ||||
| 		case "assistant": | ||||
| 			currentVars.Response = msg.Content | ||||
| 			if err := writePrompt(); err != nil { | ||||
| 				return "", err | ||||
| 			} | ||||
| 		default: | ||||
| 			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Append the last set of vars if they are non-empty | ||||
| 	if currentVars.Prompt != "" || currentVars.System != "" { | ||||
| 		if err := writePrompt(); err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return prompt.String(), nil | ||||
| 	return sb.String(), nil | ||||
| } | ||||
| 
 | ||||
| type ManifestV2 struct { | ||||
|  |  | |||
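
The restored `Model.Prompt` renders the model's Go `text/template` with three variables: `First` (true when the request carried no `Context`), `System` (request override falling back to the model default), and `Prompt`. The following exercises that rendering with a hypothetical template string; the template text is an example, not one shipped with any model.

```go
package main

import (
	"fmt"
	"strings"
	"text/template"
)

type promptVars struct {
	First  bool
	System string
	Prompt string
}

// renderPrompt mirrors the restored Model.Prompt: parse the template, fill in
// the vars, return the rendered string.
func renderPrompt(tmplText string, vars promptVars) (string, error) {
	tmpl, err := template.New("").Parse(tmplText)
	if err != nil {
		return "", err
	}
	var sb strings.Builder
	if err := tmpl.Execute(&sb, vars); err != nil {
		return "", err
	}
	return sb.String(), nil
}

func main() {
	// Hypothetical template: only emit the system prompt on the first turn.
	tmplText := "{{ if .First }}[INST] <<SYS>>{{ .System }}<</SYS>> {{ .Prompt }} [/INST]{{ else }}[INST] {{ .Prompt }} [/INST]{{ end }}"

	// First turn: the request carried no context, so First is true.
	out, err := renderPrompt(tmplText, promptVars{
		First:  true,
		System: "You are a helpful assistant.",
		Prompt: "why is the sky blue?",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(out)

	// Follow-up turn: context was provided, so First is false.
	out, _ = renderPrompt(tmplText, promptVars{First: false, Prompt: "and sunsets?"})
	fmt.Println(out)
}
```
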
|  | @ -2,15 +2,17 @@ package server | |||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jmorganca/ollama/api" | ||||
| ) | ||||
| 
 | ||||
| func TestModelPrompt(t *testing.T) { | ||||
| 	m := Model{ | ||||
| 	var m Model | ||||
| 	req := api.GenerateRequest{ | ||||
| 		Template: "a{{ .Prompt }}b", | ||||
| 	} | ||||
| 	s, err := m.Prompt(PromptVars{ | ||||
| 		Prompt:   "<h1>", | ||||
| 	}) | ||||
| 	} | ||||
| 	s, err := m.Prompt(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  |  | |||
							
								
								
									
server/routes.go | 311
							|  | @ -60,26 +60,17 @@ var loaded struct { | |||
| var defaultSessionDuration = 5 * time.Minute | ||||
| 
 | ||||
| // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function | ||||
| func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) { | ||||
| 	model, err := GetModel(modelName) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	workDir := c.GetString("workDir") | ||||
| 
 | ||||
| func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error { | ||||
| 	opts := api.DefaultOptions() | ||||
| 	if err := opts.FromMap(model.Options); err != nil { | ||||
| 		log.Printf("could not load model options: %v", err) | ||||
| 		return nil, err | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if err := opts.FromMap(reqOpts); err != nil { | ||||
| 		return nil, err | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	ctx := c.Request.Context() | ||||
| 
 | ||||
| 	// check if the loaded model is still running in a subprocess, in case something unexpected happened | ||||
| 	if loaded.runner != nil { | ||||
| 		if err := loaded.runner.Ping(ctx); err != nil { | ||||
|  | @ -115,7 +106,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess | |||
| 				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName) | ||||
| 			} | ||||
| 
 | ||||
| 			return nil, err | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		loaded.Model = model | ||||
|  | @ -149,7 +140,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess | |||
| 	} | ||||
| 
 | ||||
| 	loaded.expireTimer.Reset(sessionDuration) | ||||
| 	return model, nil | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func GenerateHandler(c *gin.Context) { | ||||
|  | @ -182,262 +173,88 @@ func GenerateHandler(c *gin.Context) { | |||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	sessionDuration := defaultSessionDuration | ||||
| 	model, err := load(c, req.Model, req.Options, sessionDuration) | ||||
| 	model, err := GetModel(req.Model) | ||||
| 	if err != nil { | ||||
| 		var pErr *fs.PathError | ||||
| 		switch { | ||||
| 		case errors.As(err, &pErr): | ||||
| 		if errors.As(err, &pErr) { | ||||
| 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) | ||||
| 		case errors.Is(err, api.ErrInvalidOpts): | ||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		default: | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	workDir := c.GetString("workDir") | ||||
| 
 | ||||
| 	// TODO: set this duration from the request if specified | ||||
| 	sessionDuration := defaultSessionDuration | ||||
| 	if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil { | ||||
| 		if errors.Is(err, api.ErrInvalidOpts) { | ||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	checkpointLoaded := time.Now() | ||||
| 
 | ||||
| 	prompt := req.Prompt | ||||
| 	if !req.Raw { | ||||
| 		prompt, err = model.Prompt(req) | ||||
| 		if err != nil { | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	ch := make(chan any) | ||||
| 	go func() { | ||||
| 		defer close(ch) | ||||
| 		// an empty request loads the model | ||||
| 		if req.Prompt == "" && req.Template == "" && req.System == "" { | ||||
| 		c.JSON(http.StatusOK, api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}) | ||||
| 			ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true} | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 	checkpointLoaded := time.Now() | ||||
| 
 | ||||
| 	var prompt string | ||||
| 	sendContext := false | ||||
| 	switch { | ||||
| 	case req.Raw: | ||||
| 		prompt = req.Prompt | ||||
| 	case req.Prompt != "": | ||||
| 		if req.Template != "" { | ||||
| 			// override the default model template | ||||
| 			model.Template = req.Template | ||||
| 		} | ||||
| 
 | ||||
| 		var rebuild strings.Builder | ||||
| 		if req.Context != nil { | ||||
| 			// TODO: context is deprecated, at some point the context logic within this conditional should be removed | ||||
| 			prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context) | ||||
| 			if err != nil { | ||||
| 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 			// Remove leading spaces from prevCtx if present | ||||
| 			prevCtx = strings.TrimPrefix(prevCtx, " ") | ||||
| 			rebuild.WriteString(prevCtx) | ||||
| 		} | ||||
| 		p, err := model.Prompt(PromptVars{ | ||||
| 			System: req.System, | ||||
| 			Prompt: req.Prompt, | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 		rebuild.WriteString(p) | ||||
| 		prompt = rebuild.String() | ||||
| 		sendContext = true | ||||
| 	} | ||||
| 
 | ||||
| 	ch := make(chan any) | ||||
| 	var generated strings.Builder | ||||
| 	go func() { | ||||
| 		defer close(ch) | ||||
| 
 | ||||
| 		fn := func(r llm.PredictResponse) { | ||||
| 			// Update model expiration | ||||
| 		fn := func(r api.GenerateResponse) { | ||||
| 			loaded.expireAt = time.Now().Add(sessionDuration) | ||||
| 			loaded.expireTimer.Reset(sessionDuration) | ||||
| 
 | ||||
| 			// Build up the full response | ||||
| 			if _, err := generated.WriteString(r.Content); err != nil { | ||||
| 				ch <- gin.H{"error": err.Error()} | ||||
| 				return | ||||
| 			r.Model = req.Model | ||||
| 			r.CreatedAt = time.Now().UTC() | ||||
| 			if r.Done { | ||||
| 				r.TotalDuration = time.Since(checkpointStart) | ||||
| 				r.LoadDuration = checkpointLoaded.Sub(checkpointStart) | ||||
| 			} | ||||
| 
 | ||||
| 			resp := api.GenerateResponse{ | ||||
| 				Model:     r.Model, | ||||
| 				CreatedAt: r.CreatedAt, | ||||
| 				Done:      r.Done, | ||||
| 				Response:  r.Content, | ||||
| 				EvalMetrics: api.EvalMetrics{ | ||||
| 					TotalDuration:      r.TotalDuration, | ||||
| 					LoadDuration:       r.LoadDuration, | ||||
| 					PromptEvalCount:    r.PromptEvalCount, | ||||
| 					PromptEvalDuration: r.PromptEvalDuration, | ||||
| 					EvalCount:          r.EvalCount, | ||||
| 					EvalDuration:       r.EvalDuration, | ||||
| 				}, | ||||
| 			if req.Raw { | ||||
| 				// in raw mode the client must manage history on their own | ||||
| 				r.Context = nil | ||||
| 			} | ||||
| 
 | ||||
| 			if r.Done && sendContext { | ||||
| 				embd, err := loaded.runner.Encode(c.Request.Context(), req.Prompt+generated.String()) | ||||
| 				if err != nil { | ||||
| 					ch <- gin.H{"error": err.Error()} | ||||
| 					return | ||||
| 				} | ||||
| 				r.Context = embd | ||||
| 			ch <- r | ||||
| 		} | ||||
| 
 | ||||
| 			ch <- resp | ||||
| 		} | ||||
| 
 | ||||
| 		// Start prediction | ||||
| 		predictReq := llm.PredictRequest{ | ||||
| 			Model:            model.Name, | ||||
| 			Prompt:           prompt, | ||||
| 			Format:           req.Format, | ||||
| 			CheckpointStart:  checkpointStart, | ||||
| 			CheckpointLoaded: checkpointLoaded, | ||||
| 		} | ||||
| 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { | ||||
| 		if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil { | ||||
| 			ch <- gin.H{"error": err.Error()} | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	if req.Stream != nil && !*req.Stream { | ||||
| 		// Wait for the channel to close | ||||
| 		var r api.GenerateResponse | ||||
| 		var sb strings.Builder | ||||
| 		var response api.GenerateResponse | ||||
| 		generated := "" | ||||
| 		for resp := range ch { | ||||
| 			var ok bool | ||||
| 			if r, ok = resp.(api.GenerateResponse); !ok { | ||||
| 			if r, ok := resp.(api.GenerateResponse); ok { | ||||
| 				generated += r.Response | ||||
| 				response = r | ||||
| 			} else { | ||||
| 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 				return | ||||
| 			} | ||||
| 			sb.WriteString(r.Response) | ||||
| 		} | ||||
| 		r.Response = sb.String() | ||||
| 		c.JSON(http.StatusOK, r) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	streamResponse(c, ch) | ||||
| } | ||||
| 
 | ||||
| func ChatHandler(c *gin.Context) { | ||||
| 	loaded.mu.Lock() | ||||
| 	defer loaded.mu.Unlock() | ||||
| 
 | ||||
| 	checkpointStart := time.Now() | ||||
| 
 | ||||
| 	var req api.ChatRequest | ||||
| 	err := c.ShouldBindJSON(&req) | ||||
| 	switch { | ||||
| 	case errors.Is(err, io.EOF): | ||||
| 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) | ||||
| 		return | ||||
| 	case err != nil: | ||||
| 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// validate the request | ||||
| 	switch { | ||||
| 	case req.Model == "": | ||||
| 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) | ||||
| 		return | ||||
| 	case len(req.Format) > 0 && req.Format != "json": | ||||
| 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	sessionDuration := defaultSessionDuration | ||||
| 	model, err := load(c, req.Model, req.Options, sessionDuration) | ||||
| 	if err != nil { | ||||
| 		var pErr *fs.PathError | ||||
| 		switch { | ||||
| 		case errors.As(err, &pErr): | ||||
| 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) | ||||
| 		case errors.Is(err, api.ErrInvalidOpts): | ||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		default: | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// an empty request loads the model | ||||
| 	if len(req.Messages) == 0 { | ||||
| 		c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	checkpointLoaded := time.Now() | ||||
| 
 | ||||
| 	if req.Template != "" { | ||||
| 		// override the default model template | ||||
| 		model.Template = req.Template | ||||
| 	} | ||||
| 	prompt, err := model.ChatPrompt(req.Messages) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	ch := make(chan any) | ||||
| 
 | ||||
| 	go func() { | ||||
| 		defer close(ch) | ||||
| 
 | ||||
| 		fn := func(r llm.PredictResponse) { | ||||
| 			// Update model expiration | ||||
| 			loaded.expireAt = time.Now().Add(sessionDuration) | ||||
| 			loaded.expireTimer.Reset(sessionDuration) | ||||
| 
 | ||||
| 			resp := api.ChatResponse{ | ||||
| 				Model:     r.Model, | ||||
| 				CreatedAt: r.CreatedAt, | ||||
| 				Done:      r.Done, | ||||
| 				EvalMetrics: api.EvalMetrics{ | ||||
| 					TotalDuration:      r.TotalDuration, | ||||
| 					LoadDuration:       r.LoadDuration, | ||||
| 					PromptEvalCount:    r.PromptEvalCount, | ||||
| 					PromptEvalDuration: r.PromptEvalDuration, | ||||
| 					EvalCount:          r.EvalCount, | ||||
| 					EvalDuration:       r.EvalDuration, | ||||
| 				}, | ||||
| 			} | ||||
| 
 | ||||
| 			if !r.Done { | ||||
| 				resp.Message = &api.Message{Role: "assistant", Content: r.Content} | ||||
| 			} | ||||
| 
 | ||||
| 			ch <- resp | ||||
| 		} | ||||
| 
 | ||||
| 		// Start prediction | ||||
| 		predictReq := llm.PredictRequest{ | ||||
| 			Model:            model.Name, | ||||
| 			Prompt:           prompt, | ||||
| 			Format:           req.Format, | ||||
| 			CheckpointStart:  checkpointStart, | ||||
| 			CheckpointLoaded: checkpointLoaded, | ||||
| 		} | ||||
| 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { | ||||
| 			ch <- gin.H{"error": err.Error()} | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	if req.Stream != nil && !*req.Stream { | ||||
| 		// Wait for the channel to close | ||||
| 		var r api.ChatResponse | ||||
| 		var sb strings.Builder | ||||
| 		for resp := range ch { | ||||
| 			var ok bool | ||||
| 			if r, ok = resp.(api.ChatResponse); !ok { | ||||
| 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 				return | ||||
| 			} | ||||
| 			if r.Message != nil { | ||||
| 				sb.WriteString(r.Message.Content) | ||||
| 			} | ||||
| 		} | ||||
| 		r.Message = &api.Message{Role: "assistant", Content: sb.String()} | ||||
| 		c.JSON(http.StatusOK, r) | ||||
| 		response.Response = generated | ||||
| 		c.JSON(http.StatusOK, response) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  | @ -464,18 +281,15 @@ func EmbeddingHandler(c *gin.Context) { | |||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	sessionDuration := defaultSessionDuration | ||||
| 	_, err = load(c, req.Model, req.Options, sessionDuration) | ||||
| 	model, err := GetModel(req.Model) | ||||
| 	if err != nil { | ||||
| 		var pErr *fs.PathError | ||||
| 		switch { | ||||
| 		case errors.As(err, &pErr): | ||||
| 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) | ||||
| 		case errors.Is(err, api.ErrInvalidOpts): | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		default: | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	workDir := c.GetString("workDir") | ||||
| 	if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  | @ -953,7 +767,6 @@ func Serve(ln net.Listener, allowOrigins []string) error { | |||
| 
 | ||||
| 	r.POST("/api/pull", PullModelHandler) | ||||
| 	r.POST("/api/generate", GenerateHandler) | ||||
| 	r.POST("/api/chat", ChatHandler) | ||||
| 	r.POST("/api/embeddings", EmbeddingHandler) | ||||
| 	r.POST("/api/create", CreateModelHandler) | ||||
| 	r.POST("/api/push", PushModelHandler) | ||||
|  |  | |||
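
With `ChatHandler` gone, `GenerateHandler` is again the single response path: the prediction callback pushes `api.GenerateResponse` values onto a channel, and when the client sent `"stream": false` the handler drains that channel, concatenating the partial `Response` strings into one final object. A stripped-down sketch of that drain pattern with simplified types (the real handler also forwards error values and carries the metrics on the final response):

```go
package main

import "fmt"

// response mirrors only the api.GenerateResponse fields this sketch needs.
type response struct {
	Response string
	Done     bool
}

// collect drains a channel of streamed responses into one final response,
// the way the handler does when streaming is disabled.
func collect(ch <-chan response) response {
	var final response
	generated := ""
	for r := range ch {
		generated += r.Response
		final = r // the last value carries Done (and, in the real type, the metrics)
	}
	final.Response = generated
	return final
}

func main() {
	ch := make(chan response)
	go func() {
		defer close(ch)
		for _, piece := range []string{"The ", "sky ", "is ", "blue."} {
			ch <- response{Response: piece}
		}
		ch <- response{Done: true}
	}()

	fmt.Printf("%+v\n", collect(ch))
}
```
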