mirror of https://github.com/ollama/ollama.git

Revert "chat api (#991)" while context variable is fixed

This reverts commit 7a0899d62d.

parent f1ef3f9947
commit 00d06619a1
@@ -221,19 +221,6 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
	})
}

type ChatResponseFunc func(ChatResponse) error

func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
	return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
		var resp ChatResponse
		if err := json.Unmarshal(bts, &resp); err != nil {
			return err
		}

		return fn(resp)
	})
}

type PullProgressFunc func(ProgressResponse) error

func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
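After this revert, the Go client streams completions through `Client.Generate` and a callback rather than the removed `Client.Chat`. A minimal sketch of that callback pattern, assuming a `*api.Client` value `c` constructed elsewhere (this helper is illustrative and not part of the diff above):

```go
package example

import (
	"context"
	"fmt"

	"github.com/jmorganca/ollama/api"
)

// streamCompletion shows the callback-style streaming kept by this revert:
// Client.Generate invokes the callback once per streamed response object.
func streamCompletion(ctx context.Context, c *api.Client, model, prompt string) error {
	req := &api.GenerateRequest{Model: model, Prompt: prompt}
	return c.Generate(ctx, req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response) // each chunk carries a fragment of the reply
		if resp.Done {
			fmt.Println() // the final object also carries the timing fields
		}
		return nil // returning an error would abort the stream
	})
}
```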

api/types.go (74 lines changed)

@@ -36,7 +36,7 @@ type GenerateRequest struct {
	Prompt   string `json:"prompt"`
	System   string `json:"system"`
	Template string `json:"template"`
	Context  []int  `json:"context,omitempty"` // DEPRECATED: context is deprecated, use the /chat endpoint instead for chat history
	Context  []int  `json:"context,omitempty"`
	Stream   *bool  `json:"stream,omitempty"`
	Raw      bool   `json:"raw,omitempty"`
	Format   string `json:"format"`
@@ -44,41 +44,6 @@ type GenerateRequest struct {
	Options map[string]interface{} `json:"options"`
}

type ChatRequest struct {
	Model    string    `json:"model"`
	Messages []Message `json:"messages"`
	Template string    `json:"template"`
	Stream   *bool     `json:"stream,omitempty"`
	Format   string    `json:"format"`

	Options map[string]interface{} `json:"options"`
}

type Message struct {
	Role    string `json:"role"` // one of ["system", "user", "assistant"]
	Content string `json:"content"`
}

type ChatResponse struct {
	Model     string    `json:"model"`
	CreatedAt time.Time `json:"created_at"`
	Message   *Message  `json:"message,omitempty"`

	Done    bool  `json:"done"`
	Context []int `json:"context,omitempty"`

	EvalMetrics
}

type EvalMetrics struct {
	TotalDuration      time.Duration `json:"total_duration,omitempty"`
	LoadDuration       time.Duration `json:"load_duration,omitempty"`
	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
	EvalCount          int           `json:"eval_count,omitempty"`
	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
}

// Options specfied in GenerateRequest, if you add a new option here add it to the API docs also
type Options struct {
	Runner
@@ -208,34 +173,39 @@ type GenerateResponse struct {
	Done    bool  `json:"done"`
	Context []int `json:"context,omitempty"`

	EvalMetrics
	TotalDuration      time.Duration `json:"total_duration,omitempty"`
	LoadDuration       time.Duration `json:"load_duration,omitempty"`
	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
	EvalCount          int           `json:"eval_count,omitempty"`
	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
}

func (m *EvalMetrics) Summary() {
	if m.TotalDuration > 0 {
		fmt.Fprintf(os.Stderr, "total duration:       %v\n", m.TotalDuration)
func (r *GenerateResponse) Summary() {
	if r.TotalDuration > 0 {
		fmt.Fprintf(os.Stderr, "total duration:       %v\n", r.TotalDuration)
	}

	if m.LoadDuration > 0 {
		fmt.Fprintf(os.Stderr, "load duration:        %v\n", m.LoadDuration)
	if r.LoadDuration > 0 {
		fmt.Fprintf(os.Stderr, "load duration:        %v\n", r.LoadDuration)
	}

	if m.PromptEvalCount > 0 {
		fmt.Fprintf(os.Stderr, "prompt eval count:    %d token(s)\n", m.PromptEvalCount)
	if r.PromptEvalCount > 0 {
		fmt.Fprintf(os.Stderr, "prompt eval count:    %d token(s)\n", r.PromptEvalCount)
	}

	if m.PromptEvalDuration > 0 {
		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration)
		fmt.Fprintf(os.Stderr, "prompt eval rate:     %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds())
	if r.PromptEvalDuration > 0 {
		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
		fmt.Fprintf(os.Stderr, "prompt eval rate:     %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
	}

	if m.EvalCount > 0 {
		fmt.Fprintf(os.Stderr, "eval count:           %d token(s)\n", m.EvalCount)
	if r.EvalCount > 0 {
		fmt.Fprintf(os.Stderr, "eval count:           %d token(s)\n", r.EvalCount)
	}

	if m.EvalDuration > 0 {
		fmt.Fprintf(os.Stderr, "eval duration:        %s\n", m.EvalDuration)
		fmt.Fprintf(os.Stderr, "eval rate:            %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds())
	if r.EvalDuration > 0 {
		fmt.Fprintf(os.Stderr, "eval duration:        %s\n", r.EvalDuration)
		fmt.Fprintf(os.Stderr, "eval rate:            %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
	}
}


docs/api.md (152 lines changed)

@@ -24,7 +24,7 @@ All durations are returned in nanoseconds.

### Streaming responses

Certain endpoints stream responses as JSON objects.
Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character.

## Generate a completion

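The restored wording above matters for client code: each streamed object arrives on its own line. A minimal sketch of consuming such a newline-delimited JSON stream with the Go standard library (the URL and field names follow the examples later in this document; the struct is trimmed to what the sketch reads):

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// generateChunk holds only the fields this sketch uses from each streamed object.
type generateChunk struct {
	Response string `json:"response"`
	Done     bool   `json:"done"`
}

func main() {
	body := []byte(`{"model": "llama2", "prompt": "why is the sky blue?"}`)
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Each line of the response body is one JSON object.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var chunk generateChunk
		if err := json.Unmarshal(scanner.Bytes(), &chunk); err != nil {
			log.Fatal(err)
		}
		fmt.Print(chunk.Response)
		if chunk.Done {
			fmt.Println()
		}
	}
}
```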
@@ -32,12 +32,10 @@ Certain endpoints stream responses as JSON objects.
POST /api/generate
```

Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request.

### Parameters

`model` is required.

- `model`: (required) the [model name](#model-names)
- `prompt`: the prompt to generate a response for

@@ -45,10 +43,11 @@ Advanced parameters (optional):

- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `system`: system prompt to (overrides what is defined in the `Modelfile`)
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.

### JSON mode

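The `context` parameter restored in the list above is what keeps a short conversational memory on `/api/generate`. A minimal, non-streaming sketch of feeding the `context` from one response into the next request (stdlib only; error handling is abbreviated and the model name is taken from the examples in this document):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// ask posts a prompt together with the context returned by the previous call.
// Pass nil on the first turn; pass the returned slice on later turns.
func ask(prompt string, prevContext []int) ([]int, error) {
	reqBody, _ := json.Marshal(map[string]any{ // marshal error ignored for brevity
		"model":   "llama2",
		"prompt":  prompt,
		"stream":  false,
		"context": prevContext,
	})
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(reqBody))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var out struct {
		Response string `json:"response"`
		Context  []int  `json:"context"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, err
	}
	fmt.Println(out.Response)
	return out.Context, nil // feed this into the next call
}
```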
@@ -58,7 +57,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur

### Examples

#### Request (Prompt)
#### Request

```shell
curl http://localhost:11434/api/generate -d '{
@@ -90,7 +89,7 @@ The final response in the stream also includes additional data about the generat
- `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
- `eval_count`: number of tokens the response
- `eval_duration`: time in nanoseconds spent generating the response
- `context`: deprecated, an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
- `response`: empty if the response was streamed, if not streamed, this will contain the full response

To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.
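As a concrete check of that formula, using example timings that appear elsewhere in this document (`eval_count` 113, `eval_duration` 1325948000 ns): 113 / 1.325948 s ≈ 85 tokens/s. The same calculation in Go:

```go
package main

import "fmt"

func main() {
	// Durations in the API are reported in nanoseconds.
	evalCount := 113           // tokens in the response (from an example above)
	evalDuration := 1325948000 // nanoseconds spent generating them
	rate := float64(evalCount) / (float64(evalDuration) / 1e9)
	fmt.Printf("%.2f tokens/s\n", rate) // prints roughly 85.22 tokens/s
}
```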
@@ -115,8 +114,6 @@ To calculate how fast the response is generated in tokens per second (token/s),

#### Request (No streaming)

A response can be recieved in one reply when streaming is off.

```shell
curl http://localhost:11434/api/generate -d '{
  "model": "llama2",
@@ -147,9 +144,9 @@ If `stream` is set to `false`, the response will be a single JSON object:
}
```

#### Request (Raw Mode)
#### Request (Raw mode)

In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting.
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.

```shell
curl http://localhost:11434/api/generate -d '{
@@ -167,7 +164,6 @@ curl http://localhost:11434/api/generate -d '{
  "model": "mistral",
  "created_at": "2023-11-03T15:36:02.583064Z",
  "response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
  "context": [1, 2, 3],
  "done": true,
  "total_duration": 14648695333,
  "load_duration": 3302671417,
@@ -279,6 +275,7 @@ curl http://localhost:11434/api/generate -d '{
  "model": "llama2",
  "created_at": "2023-08-04T19:22:45.499127Z",
  "response": "The sky is blue because it is the color of the sky.",
  "context": [1, 2, 3],
  "done": true,
  "total_duration": 5589157167,
  "load_duration": 3013701500,
@@ -291,135 +288,6 @@ curl http://localhost:11434/api/generate -d '{
}
```

## Send Chat Messages
```shell
POST /api/chat
```

Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.

### Parameters

`model` is required.

- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory

Advanced parameters (optional):

- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects

### Examples

#### Request
Send a chat message with a streaming response.

```shell
curl http://localhost:11434/api/generate -d '{
  "model": "llama2",
  "messages": [
    {
      "role": "user",
      "content": "why is the sky blue?"
    }
  ]
}'
```

#### Response

A stream of JSON objects is returned:

```json
{
  "model": "llama2",
  "created_at": "2023-08-04T08:52:19.385406455-07:00",
  "message": {
    "role": "assisant",
    "content": "The"
  },
  "done": false
}
```

Final response:

```json
{
  "model": "llama2",
  "created_at": "2023-08-04T19:22:45.499127Z",
  "done": true,
  "total_duration": 5589157167,
  "load_duration": 3013701500,
  "sample_count": 114,
  "sample_duration": 81442000,
  "prompt_eval_count": 46,
  "prompt_eval_duration": 1160282000,
  "eval_count": 113,
  "eval_duration": 1325948000
}
```

#### Request (With History)
Send a chat message with a conversation history.

```shell
curl http://localhost:11434/api/generate -d '{
  "model": "llama2",
  "messages": [
    {
      "role": "user",
      "content": "why is the sky blue?"
    },
    {
      "role": "assistant",
      "content": "due to rayleigh scattering."
    },
    {
      "role": "user",
      "content": "how is that different than mie scattering?"
    }
  ]
}'
```

#### Response

A stream of JSON objects is returned:

```json
{
  "model": "llama2",
  "created_at": "2023-08-04T08:52:19.385406455-07:00",
  "message": {
    "role": "assisant",
    "content": "The"
  },
  "done": false
}
```

Final response:

```json
{
  "model": "llama2",
  "created_at": "2023-08-04T19:22:45.499127Z",
  "done": true,
  "total_duration": 5589157167,
  "load_duration": 3013701500,
  "sample_count": 114,
  "sample_duration": 81442000,
  "prompt_eval_count": 46,
  "prompt_eval_duration": 1160282000,
  "eval_count": 113,
  "eval_duration": 1325948000
}
```

## Create a Model

```shell

llm/llama.go (54 lines changed)

@@ -531,31 +531,21 @@ type prediction struct {

const maxBufferSize = 512 * format.KiloByte

type PredictRequest struct {
	Model            string
	Prompt           string
	Format           string
	CheckpointStart  time.Time
	CheckpointLoaded time.Time
}
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
	prevConvo, err := llm.Decode(ctx, prevContext)
	if err != nil {
		return err
	}

type PredictResponse struct {
	Model              string
	CreatedAt          time.Time
	TotalDuration      time.Duration
	LoadDuration       time.Duration
	Content            string
	Done               bool
	PromptEvalCount    int
	PromptEvalDuration time.Duration
	EvalCount          int
	EvalDuration       time.Duration
	Context            []int
}
	// Remove leading spaces from prevConvo if present
	prevConvo = strings.TrimPrefix(prevConvo, " ")

	var nextContext strings.Builder
	nextContext.WriteString(prevConvo)
	nextContext.WriteString(prompt)

func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(PredictResponse)) error {
	request := map[string]any{
		"prompt":            predict.Prompt,
		"prompt":            nextContext.String(),
		"stream":            true,
		"n_predict":         llm.NumPredict,
		"n_keep":            llm.NumKeep,
@@ -577,7 +567,7 @@ func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(P
		"stop":              llm.Stop,
	}

	if predict.Format == "json" {
	if format == "json" {
		request["grammar"] = jsonGrammar
	}

@@ -634,25 +624,25 @@ func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(P
			}

			if p.Content != "" {
				fn(PredictResponse{
					Model:     predict.Model,
					CreatedAt: time.Now().UTC(),
					Content:   p.Content,
				})
				fn(api.GenerateResponse{Response: p.Content})
				nextContext.WriteString(p.Content)
			}

			if p.Stop {
				fn(PredictResponse{
					Model:         predict.Model,
					CreatedAt:     time.Now().UTC(),
					TotalDuration: time.Since(predict.CheckpointStart),
				embd, err := llm.Encode(ctx, nextContext.String())
				if err != nil {
					return fmt.Errorf("encoding context: %v", err)
				}

				fn(api.GenerateResponse{
					Done:               true,
					Context:            embd,
					PromptEvalCount:    p.Timings.PromptN,
					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
					EvalCount:          p.Timings.PredictedN,
					EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
				})

				return nil
			}
		}
@@ -14,7 +14,7 @@ import (
)

type LLM interface {
	Predict(context.Context, PredictRequest, func(PredictResponse)) error
	Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
	Embedding(context.Context, string) ([]float64, error)
	Encode(context.Context, string) ([]int, error)
	Decode(context.Context, []int) (string, error)
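The restored `Predict` above rebuilds the prompt by decoding the previous context, trimming a single leading space, and appending the new prompt; when generation stops it re-encodes the accumulated conversation to produce the next context. A minimal sketch of that round trip against the `Encode`/`Decode` methods of the `LLM` interface shown above (the small local interface and helper here are illustrative, not code from the repository):

```go
package example

import (
	"context"
	"strings"
)

// codec is the subset of the LLM interface used for context handling,
// redeclared here only to keep the sketch self-contained.
type codec interface {
	Encode(ctx context.Context, text string) ([]int, error)
	Decode(ctx context.Context, tokens []int) (string, error)
}

// nextPromptAndContext mirrors the restored flow: decode the prior context,
// trim the leading space the tokenizer may add, append the new prompt, and
// after generation re-encode the whole conversation as the next context.
func nextPromptAndContext(ctx context.Context, llm codec, prevContext []int, prompt, reply string) (string, []int, error) {
	prevConvo, err := llm.Decode(ctx, prevContext)
	if err != nil {
		return "", nil, err
	}
	prevConvo = strings.TrimPrefix(prevConvo, " ")

	var next strings.Builder
	next.WriteString(prevConvo)
	next.WriteString(prompt)
	fullPrompt := next.String()

	// The reply is appended and re-encoded so the caller can send the
	// resulting context with its next request.
	next.WriteString(reply)
	embd, err := llm.Encode(ctx, next.String())
	if err != nil {
		return "", nil, err
	}
	return fullPrompt, embd, nil
}
```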
@@ -47,82 +47,37 @@ type Model struct {
	Options       map[string]interface{}
}

type PromptVars struct {
	System   string
	Prompt   string
	Response string
}
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
	t := m.Template
	if request.Template != "" {
		t = request.Template
	}

func (m *Model) Prompt(p PromptVars) (string, error) {
	var prompt strings.Builder
	tmpl, err := template.New("").Parse(m.Template)
	tmpl, err := template.New("").Parse(t)
	if err != nil {
		return "", err
	}

	if p.System == "" {
		// use the default system prompt for this model if one is not specified
		p.System = m.System
	var vars struct {
		First  bool
		System string
		Prompt string
	}

	vars.First = len(request.Context) == 0
	vars.System = m.System
	vars.Prompt = request.Prompt

	if request.System != "" {
		vars.System = request.System
	}

	var sb strings.Builder
	if err := tmpl.Execute(&sb, p); err != nil {
	if err := tmpl.Execute(&sb, vars); err != nil {
		return "", err
	}
	prompt.WriteString(sb.String())
	prompt.WriteString(p.Response)
	return prompt.String(), nil
}

func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
	// build the prompt from the list of messages
	var prompt strings.Builder
	currentVars := PromptVars{}

	writePrompt := func() error {
		p, err := m.Prompt(currentVars)
		if err != nil {
			return err
		}
		prompt.WriteString(p)
		currentVars = PromptVars{}
		return nil
	}

	for _, msg := range msgs {
		switch msg.Role {
		case "system":
			if currentVars.Prompt != "" || currentVars.System != "" {
				if err := writePrompt(); err != nil {
					return "", err
				}
			}
			currentVars.System = msg.Content
		case "user":
			if currentVars.Prompt != "" || currentVars.System != "" {
				if err := writePrompt(); err != nil {
					return "", err
				}
			}
			currentVars.Prompt = msg.Content
		case "assistant":
			currentVars.Response = msg.Content
			if err := writePrompt(); err != nil {
				return "", err
			}
		default:
			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
		}
	}

	// Append the last set of vars if they are non-empty
	if currentVars.Prompt != "" || currentVars.System != "" {
		if err := writePrompt(); err != nil {
			return "", err
		}
	}

	return prompt.String(), nil
	return sb.String(), nil
}

type ManifestV2 struct {
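The restored `Prompt` fills a struct with `First`, `System` and `Prompt` fields and executes the model template against it. A small self-contained illustration of that mechanism using `text/template` (the template string and values here are arbitrary examples, not a real Modelfile template):

```go
package main

import (
	"fmt"
	"strings"
	"text/template"
)

func main() {
	// Arbitrary example template; real templates come from the Modelfile.
	const tmplText = "{{ if .First }}{{ .System }}\n{{ end }}USER: {{ .Prompt }}\nASSISTANT: "

	var vars struct {
		First  bool
		System string
		Prompt string
	}
	vars.First = true // true when no previous context was supplied
	vars.System = "You are a helpful assistant."
	vars.Prompt = "why is the sky blue?"

	tmpl, err := template.New("").Parse(tmplText)
	if err != nil {
		panic(err)
	}

	var sb strings.Builder
	if err := tmpl.Execute(&sb, vars); err != nil {
		panic(err)
	}
	fmt.Println(sb.String())
}
```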
@@ -2,15 +2,17 @@ package server

import (
	"testing"

	"github.com/jmorganca/ollama/api"
)

func TestModelPrompt(t *testing.T) {
	m := Model{
	var m Model
	req := api.GenerateRequest{
		Template: "a{{ .Prompt }}b",
		Prompt:   "<h1>",
	}
	s, err := m.Prompt(PromptVars{
		Prompt: "<h1>",
	})
	s, err := m.Prompt(req)
	if err != nil {
		t.Fatal(err)
	}

server/routes.go (313 lines changed)

@@ -60,26 +60,17 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute

// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
	model, err := GetModel(modelName)
	if err != nil {
		return nil, err
	}

	workDir := c.GetString("workDir")

func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
	opts := api.DefaultOptions()
	if err := opts.FromMap(model.Options); err != nil {
		log.Printf("could not load model options: %v", err)
		return nil, err
		return err
	}

	if err := opts.FromMap(reqOpts); err != nil {
		return nil, err
		return err
	}

	ctx := c.Request.Context()

	// check if the loaded model is still running in a subprocess, in case something unexpected happened
	if loaded.runner != nil {
		if err := loaded.runner.Ping(ctx); err != nil {
@@ -115,7 +106,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
			}

			return nil, err
			return err
		}

		loaded.Model = model
@@ -149,7 +140,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
	}

	loaded.expireTimer.Reset(sessionDuration)
	return model, nil
	return nil
}

func GenerateHandler(c *gin.Context) {
@@ -182,262 +173,88 @@ func GenerateHandler(c *gin.Context) {
		return
	}

	sessionDuration := defaultSessionDuration
	model, err := load(c, req.Model, req.Options, sessionDuration)
	model, err := GetModel(req.Model)
	if err != nil {
		var pErr *fs.PathError
		switch {
		case errors.As(err, &pErr):
		if errors.As(err, &pErr) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
		case errors.Is(err, api.ErrInvalidOpts):
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	// an empty request loads the model
	if req.Prompt == "" && req.Template == "" && req.System == "" {
		c.JSON(http.StatusOK, api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
		return
	}

	checkpointLoaded := time.Now()

	var prompt string
	sendContext := false
	switch {
	case req.Raw:
		prompt = req.Prompt
	case req.Prompt != "":
		if req.Template != "" {
			// override the default model template
			model.Template = req.Template
		}

		var rebuild strings.Builder
		if req.Context != nil {
			// TODO: context is deprecated, at some point the context logic within this conditional should be removed
			prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}

			// Remove leading spaces from prevCtx if present
			prevCtx = strings.TrimPrefix(prevCtx, " ")
			rebuild.WriteString(prevCtx)
		}
		p, err := model.Prompt(PromptVars{
			System: req.System,
			Prompt: req.Prompt,
		})
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		rebuild.WriteString(p)
		prompt = rebuild.String()
		sendContext = true
	}

	ch := make(chan any)
	var generated strings.Builder
	go func() {
		defer close(ch)

		fn := func(r llm.PredictResponse) {
			// Update model expiration
			loaded.expireAt = time.Now().Add(sessionDuration)
			loaded.expireTimer.Reset(sessionDuration)

			// Build up the full response
			if _, err := generated.WriteString(r.Content); err != nil {
				ch <- gin.H{"error": err.Error()}
				return
			}

			resp := api.GenerateResponse{
				Model:     r.Model,
				CreatedAt: r.CreatedAt,
				Done:      r.Done,
				Response:  r.Content,
				EvalMetrics: api.EvalMetrics{
					TotalDuration:      r.TotalDuration,
					LoadDuration:       r.LoadDuration,
					PromptEvalCount:    r.PromptEvalCount,
					PromptEvalDuration: r.PromptEvalDuration,
					EvalCount:          r.EvalCount,
					EvalDuration:       r.EvalDuration,
				},
			}

			if r.Done && sendContext {
				embd, err := loaded.runner.Encode(c.Request.Context(), req.Prompt+generated.String())
				if err != nil {
					ch <- gin.H{"error": err.Error()}
					return
				}
				r.Context = embd
			}

			ch <- resp
		}

		// Start prediction
		predictReq := llm.PredictRequest{
			Model:            model.Name,
			Prompt:           prompt,
			Format:           req.Format,
			CheckpointStart:  checkpointStart,
			CheckpointLoaded: checkpointLoaded,
		}
		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		// Wait for the channel to close
		var r api.GenerateResponse
		var sb strings.Builder
		for resp := range ch {
			var ok bool
			if r, ok = resp.(api.GenerateResponse); !ok {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}
			sb.WriteString(r.Response)
		}
		r.Response = sb.String()
		c.JSON(http.StatusOK, r)
		return
	}

	streamResponse(c, ch)
}

func ChatHandler(c *gin.Context) {
	loaded.mu.Lock()
	defer loaded.mu.Unlock()

	checkpointStart := time.Now()

	var req api.ChatRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// validate the request
	switch {
	case req.Model == "":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	case len(req.Format) > 0 && req.Format != "json":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
		return
	}

	sessionDuration := defaultSessionDuration
	model, err := load(c, req.Model, req.Options, sessionDuration)
	if err != nil {
		var pErr *fs.PathError
		switch {
		case errors.As(err, &pErr):
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
		case errors.Is(err, api.ErrInvalidOpts):
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	// an empty request loads the model
	if len(req.Messages) == 0 {
		c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
		return
	}

	checkpointLoaded := time.Now()

	if req.Template != "" {
		// override the default model template
		model.Template = req.Template
	}
	prompt, err := model.ChatPrompt(req.Messages)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	ch := make(chan any)
	workDir := c.GetString("workDir")

	// TODO: set this duration from the request if specified
	sessionDuration := defaultSessionDuration
	if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
		if errors.Is(err, api.ErrInvalidOpts) {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	checkpointLoaded := time.Now()

	prompt := req.Prompt
	if !req.Raw {
		prompt, err = model.Prompt(req)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
	}

	ch := make(chan any)
	go func() {
		defer close(ch)
		// an empty request loads the model
		if req.Prompt == "" && req.Template == "" && req.System == "" {
			ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
			return
		}

		fn := func(r llm.PredictResponse) {
			// Update model expiration
		fn := func(r api.GenerateResponse) {
			loaded.expireAt = time.Now().Add(sessionDuration)
			loaded.expireTimer.Reset(sessionDuration)

			resp := api.ChatResponse{
				Model:     r.Model,
				CreatedAt: r.CreatedAt,
				Done:      r.Done,
				EvalMetrics: api.EvalMetrics{
					TotalDuration:      r.TotalDuration,
					LoadDuration:       r.LoadDuration,
					PromptEvalCount:    r.PromptEvalCount,
					PromptEvalDuration: r.PromptEvalDuration,
					EvalCount:          r.EvalCount,
					EvalDuration:       r.EvalDuration,
				},
			r.Model = req.Model
			r.CreatedAt = time.Now().UTC()
			if r.Done {
				r.TotalDuration = time.Since(checkpointStart)
				r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
			}

			if !r.Done {
				resp.Message = &api.Message{Role: "assistant", Content: r.Content}
			if req.Raw {
				// in raw mode the client must manage history on their own
				r.Context = nil
			}

			ch <- resp
			ch <- r
		}

		// Start prediction
		predictReq := llm.PredictRequest{
			Model:            model.Name,
			Prompt:           prompt,
			Format:           req.Format,
			CheckpointStart:  checkpointStart,
			CheckpointLoaded: checkpointLoaded,
		}
		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
		if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		// Wait for the channel to close
		var r api.ChatResponse
		var sb strings.Builder
		var response api.GenerateResponse
		generated := ""
		for resp := range ch {
			var ok bool
			if r, ok = resp.(api.ChatResponse); !ok {
			if r, ok := resp.(api.GenerateResponse); ok {
				generated += r.Response
				response = r
			} else {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}
			if r.Message != nil {
				sb.WriteString(r.Message.Content)
			}
		}
		r.Message = &api.Message{Role: "assistant", Content: sb.String()}
		c.JSON(http.StatusOK, r)
		response.Response = generated
		c.JSON(http.StatusOK, response)
		return
	}

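Both versions of GenerateHandler shown above share the same shape: a producer goroutine pushes response values into a channel, and the handler either streams them or, when `stream` is false, drains the channel into a single accumulated response. A stripped-down sketch of that pattern, with plain strings standing in for the response structs:

```go
package main

import (
	"fmt"
	"strings"
)

// collect drains the channel and concatenates the chunks, mirroring the
// non-streaming branch of the handler; the streaming branch would instead
// forward each value to the client as it arrives.
func collect(ch <-chan string) string {
	var sb strings.Builder
	for chunk := range ch {
		sb.WriteString(chunk)
	}
	return sb.String()
}

func main() {
	ch := make(chan string)
	go func() {
		defer close(ch)
		for _, chunk := range []string{"The sky ", "appears blue ", "because of Rayleigh scattering."} {
			ch <- chunk // the real handler sends api.GenerateResponse values
		}
	}()
	fmt.Println(collect(ch))
}
```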
@@ -464,18 +281,15 @@ func EmbeddingHandler(c *gin.Context) {
		return
	}

	sessionDuration := defaultSessionDuration
	_, err = load(c, req.Model, req.Options, sessionDuration)
	model, err := GetModel(req.Model)
	if err != nil {
		var pErr *fs.PathError
		switch {
		case errors.As(err, &pErr):
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
		case errors.Is(err, api.ErrInvalidOpts):
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	workDir := c.GetString("workDir")
	if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

@@ -953,7 +767,6 @@ func Serve(ln net.Listener, allowOrigins []string) error {

	r.POST("/api/pull", PullModelHandler)
	r.POST("/api/generate", GenerateHandler)
	r.POST("/api/chat", ChatHandler)
	r.POST("/api/embeddings", EmbeddingHandler)
	r.POST("/api/create", CreateModelHandler)
	r.POST("/api/push", PushModelHandler)