| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
// Package openai provides middleware for partial compatibility with the OpenAI REST API.
 | 
					
						
							|  |  |  | package openai | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"bytes" | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 	"encoding/base64" | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	"encoding/json" | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	"fmt" | 
					
						
							|  |  |  | 	"io" | 
					
						
							| 
									
										
										
										
											2024-07-17 11:52:59 +08:00
										 |  |  | 	"log/slog" | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	"math/rand" | 
					
						
							|  |  |  | 	"net/http" | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 	"strings" | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	"time" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/gin-gonic/gin" | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-27 04:04:17 +08:00
										 |  |  | 	"github.com/ollama/ollama/api" | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 	"github.com/ollama/ollama/types/model" | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type Error struct { | 
					
						
							|  |  |  | 	Message string      `json:"message"` | 
					
						
							|  |  |  | 	Type    string      `json:"type"` | 
					
						
							|  |  |  | 	Param   interface{} `json:"param"` | 
					
						
							|  |  |  | 	Code    *string     `json:"code"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type ErrorResponse struct { | 
					
						
							|  |  |  | 	Error Error `json:"error"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type Message struct { | 
					
						
							| 
									
										
										
										
											2024-07-17 11:52:59 +08:00
										 |  |  | 	Role      string     `json:"role"` | 
					
						
							|  |  |  | 	Content   any        `json:"content"` | 
					
						
							|  |  |  | 	ToolCalls []ToolCall `json:"tool_calls,omitempty"` | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type Choice struct { | 
					
						
							|  |  |  | 	Index        int     `json:"index"` | 
					
						
							|  |  |  | 	Message      Message `json:"message"` | 
					
						
							|  |  |  | 	FinishReason *string `json:"finish_reason"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type ChunkChoice struct { | 
					
						
							|  |  |  | 	Index        int     `json:"index"` | 
					
						
							|  |  |  | 	Delta        Message `json:"delta"` | 
					
						
							|  |  |  | 	FinishReason *string `json:"finish_reason"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | type CompleteChunkChoice struct { | 
					
						
							|  |  |  | 	Text         string  `json:"text"` | 
					
						
							|  |  |  | 	Index        int     `json:"index"` | 
					
						
							|  |  |  | 	FinishReason *string `json:"finish_reason"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | type Usage struct { | 
					
						
							|  |  |  | 	PromptTokens     int `json:"prompt_tokens"` | 
					
						
							|  |  |  | 	CompletionTokens int `json:"completion_tokens"` | 
					
						
							|  |  |  | 	TotalTokens      int `json:"total_tokens"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type ResponseFormat struct { | 
					
						
							|  |  |  | 	Type string `json:"type"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | type EmbedRequest struct { | 
					
						
							|  |  |  | 	Input any    `json:"input"` | 
					
						
							|  |  |  | 	Model string `json:"model"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | type ChatCompletionRequest struct { | 
					
						
							|  |  |  | 	Model            string          `json:"model"` | 
					
						
							|  |  |  | 	Messages         []Message       `json:"messages"` | 
					
						
							|  |  |  | 	Stream           bool            `json:"stream"` | 
					
						
							|  |  |  | 	MaxTokens        *int            `json:"max_tokens"` | 
					
						
							|  |  |  | 	Seed             *int            `json:"seed"` | 
					
						
							|  |  |  | 	Stop             any             `json:"stop"` | 
					
						
							|  |  |  | 	Temperature      *float64        `json:"temperature"` | 
					
						
							|  |  |  | 	FrequencyPenalty *float64        `json:"frequency_penalty"` | 
					
						
							|  |  |  | 	PresencePenalty  *float64        `json:"presence_penalty_penalty"` | 
					
						
							|  |  |  | 	TopP             *float64        `json:"top_p"` | 
					
						
							|  |  |  | 	ResponseFormat   *ResponseFormat `json:"response_format"` | 
					
						
							| 
									
										
										
										
											2024-07-17 11:52:59 +08:00
										 |  |  | 	Tools            []api.Tool      `json:"tools"` | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type ChatCompletion struct { | 
					
						
							|  |  |  | 	Id                string   `json:"id"` | 
					
						
							|  |  |  | 	Object            string   `json:"object"` | 
					
						
							|  |  |  | 	Created           int64    `json:"created"` | 
					
						
							|  |  |  | 	Model             string   `json:"model"` | 
					
						
							|  |  |  | 	SystemFingerprint string   `json:"system_fingerprint"` | 
					
						
							|  |  |  | 	Choices           []Choice `json:"choices"` | 
					
						
							|  |  |  | 	Usage             Usage    `json:"usage,omitempty"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type ChatCompletionChunk struct { | 
					
						
							|  |  |  | 	Id                string        `json:"id"` | 
					
						
							|  |  |  | 	Object            string        `json:"object"` | 
					
						
							|  |  |  | 	Created           int64         `json:"created"` | 
					
						
							|  |  |  | 	Model             string        `json:"model"` | 
					
						
							|  |  |  | 	SystemFingerprint string        `json:"system_fingerprint"` | 
					
						
							|  |  |  | 	Choices           []ChunkChoice `json:"choices"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | // TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
 | 
					
						
							|  |  |  | type CompletionRequest struct { | 
					
						
							|  |  |  | 	Model            string   `json:"model"` | 
					
						
							|  |  |  | 	Prompt           string   `json:"prompt"` | 
					
						
							|  |  |  | 	FrequencyPenalty float32  `json:"frequency_penalty"` | 
					
						
							|  |  |  | 	MaxTokens        *int     `json:"max_tokens"` | 
					
						
							|  |  |  | 	PresencePenalty  float32  `json:"presence_penalty"` | 
					
						
							|  |  |  | 	Seed             *int     `json:"seed"` | 
					
						
							|  |  |  | 	Stop             any      `json:"stop"` | 
					
						
							|  |  |  | 	Stream           bool     `json:"stream"` | 
					
						
							|  |  |  | 	Temperature      *float32 `json:"temperature"` | 
					
						
							|  |  |  | 	TopP             float32  `json:"top_p"` | 
					
						
							| 
									
										
										
										
											2024-07-17 11:50:14 +08:00
										 |  |  | 	Suffix           string   `json:"suffix"` | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type Completion struct { | 
					
						
							|  |  |  | 	Id                string                `json:"id"` | 
					
						
							|  |  |  | 	Object            string                `json:"object"` | 
					
						
							|  |  |  | 	Created           int64                 `json:"created"` | 
					
						
							|  |  |  | 	Model             string                `json:"model"` | 
					
						
							|  |  |  | 	SystemFingerprint string                `json:"system_fingerprint"` | 
					
						
							|  |  |  | 	Choices           []CompleteChunkChoice `json:"choices"` | 
					
						
							|  |  |  | 	Usage             Usage                 `json:"usage,omitempty"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type CompletionChunk struct { | 
					
						
							|  |  |  | 	Id                string                `json:"id"` | 
					
						
							|  |  |  | 	Object            string                `json:"object"` | 
					
						
							|  |  |  | 	Created           int64                 `json:"created"` | 
					
						
							|  |  |  | 	Choices           []CompleteChunkChoice `json:"choices"` | 
					
						
							|  |  |  | 	Model             string                `json:"model"` | 
					
						
							|  |  |  | 	SystemFingerprint string                `json:"system_fingerprint"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-17 11:52:59 +08:00
										 |  |  | type ToolCall struct { | 
					
						
							|  |  |  | 	ID       string `json:"id"` | 
					
						
							|  |  |  | 	Type     string `json:"type"` | 
					
						
							|  |  |  | 	Function struct { | 
					
						
							|  |  |  | 		Name      string `json:"name"` | 
					
						
							|  |  |  | 		Arguments string `json:"arguments"` | 
					
						
							|  |  |  | 	} `json:"function"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | type Model struct { | 
					
						
							|  |  |  | 	Id      string `json:"id"` | 
					
						
							|  |  |  | 	Object  string `json:"object"` | 
					
						
							|  |  |  | 	Created int64  `json:"created"` | 
					
						
							|  |  |  | 	OwnedBy string `json:"owned_by"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | type Embedding struct { | 
					
						
							|  |  |  | 	Object    string    `json:"object"` | 
					
						
							|  |  |  | 	Embedding []float32 `json:"embedding"` | 
					
						
							|  |  |  | 	Index     int       `json:"index"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | type ListCompletion struct { | 
					
						
							|  |  |  | 	Object string  `json:"object"` | 
					
						
							|  |  |  | 	Data   []Model `json:"data"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | type EmbeddingList struct { | 
					
						
							| 
									
										
										
										
											2024-08-02 06:49:37 +08:00
										 |  |  | 	Object string         `json:"object"` | 
					
						
							|  |  |  | 	Data   []Embedding    `json:"data"` | 
					
						
							|  |  |  | 	Model  string         `json:"model"` | 
					
						
							|  |  |  | 	Usage  EmbeddingUsage `json:"usage,omitempty"` | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type EmbeddingUsage struct { | 
					
						
							|  |  |  | 	PromptTokens int `json:"prompt_tokens"` | 
					
						
							|  |  |  | 	TotalTokens  int `json:"total_tokens"` | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | func NewError(code int, message string) ErrorResponse { | 
					
						
							|  |  |  | 	var etype string | 
					
						
							|  |  |  | 	switch code { | 
					
						
							|  |  |  | 	case http.StatusBadRequest: | 
					
						
							|  |  |  | 		etype = "invalid_request_error" | 
					
						
							|  |  |  | 	case http.StatusNotFound: | 
					
						
							|  |  |  | 		etype = "not_found_error" | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		etype = "api_error" | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return ErrorResponse{Error{Type: etype, Message: message}} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-17 11:52:59 +08:00
										 |  |  | func toolCallId() string { | 
					
						
							|  |  |  | 	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789" | 
					
						
							|  |  |  | 	b := make([]byte, 8) | 
					
						
							|  |  |  | 	for i := range b { | 
					
						
							|  |  |  | 		b[i] = letterBytes[rand.Intn(len(letterBytes))] | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return "call_" + strings.ToLower(string(b)) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { | 
					
						
							| 
									
										
										
										
											2024-07-17 11:52:59 +08:00
										 |  |  | 	toolCalls := make([]ToolCall, len(r.Message.ToolCalls)) | 
					
						
							|  |  |  | 	for i, tc := range r.Message.ToolCalls { | 
					
						
							|  |  |  | 		toolCalls[i].ID = toolCallId() | 
					
						
							|  |  |  | 		toolCalls[i].Type = "function" | 
					
						
							|  |  |  | 		toolCalls[i].Function.Name = tc.Function.Name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		args, err := json.Marshal(tc.Function.Arguments) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			slog.Error("could not marshall function arguments to json", "error", err) | 
					
						
							|  |  |  | 			continue | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		toolCalls[i].Function.Arguments = string(args) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	return ChatCompletion{ | 
					
						
							|  |  |  | 		Id:                id, | 
					
						
							|  |  |  | 		Object:            "chat.completion", | 
					
						
							|  |  |  | 		Created:           r.CreatedAt.Unix(), | 
					
						
							|  |  |  | 		Model:             r.Model, | 
					
						
							|  |  |  | 		SystemFingerprint: "fp_ollama", | 
					
						
							|  |  |  | 		Choices: []Choice{{ | 
					
						
							| 
									
										
										
										
											2024-05-12 06:31:41 +08:00
										 |  |  | 			Index:   0, | 
					
						
							| 
									
										
										
										
											2024-07-17 11:52:59 +08:00
										 |  |  | 			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls}, | 
					
						
							| 
									
										
										
										
											2024-05-12 06:31:41 +08:00
										 |  |  | 			FinishReason: func(reason string) *string { | 
					
						
							| 
									
										
										
										
											2024-07-30 04:56:57 +08:00
										 |  |  | 				if len(toolCalls) > 0 { | 
					
						
							|  |  |  | 					reason = "tool_calls" | 
					
						
							|  |  |  | 				} | 
					
						
							| 
									
										
										
										
											2024-05-12 06:31:41 +08:00
										 |  |  | 				if len(reason) > 0 { | 
					
						
							|  |  |  | 					return &reason | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 				return nil | 
					
						
							|  |  |  | 			}(r.DoneReason), | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 		}}, | 
					
						
							|  |  |  | 		Usage: Usage{ | 
					
						
							|  |  |  | 			PromptTokens:     r.PromptEvalCount, | 
					
						
							|  |  |  | 			CompletionTokens: r.EvalCount, | 
					
						
							|  |  |  | 			TotalTokens:      r.PromptEvalCount + r.EvalCount, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { | 
					
						
							|  |  |  | 	return ChatCompletionChunk{ | 
					
						
							|  |  |  | 		Id:                id, | 
					
						
							|  |  |  | 		Object:            "chat.completion.chunk", | 
					
						
							|  |  |  | 		Created:           time.Now().Unix(), | 
					
						
							|  |  |  | 		Model:             r.Model, | 
					
						
							|  |  |  | 		SystemFingerprint: "fp_ollama", | 
					
						
							| 
									
										
										
										
											2024-05-12 06:31:41 +08:00
										 |  |  | 		Choices: []ChunkChoice{{ | 
					
						
							|  |  |  | 			Index: 0, | 
					
						
							|  |  |  | 			Delta: Message{Role: "assistant", Content: r.Message.Content}, | 
					
						
							|  |  |  | 			FinishReason: func(reason string) *string { | 
					
						
							|  |  |  | 				if len(reason) > 0 { | 
					
						
							|  |  |  | 					return &reason | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 				return nil | 
					
						
							|  |  |  | 			}(r.DoneReason), | 
					
						
							|  |  |  | 		}}, | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | func toCompletion(id string, r api.GenerateResponse) Completion { | 
					
						
							|  |  |  | 	return Completion{ | 
					
						
							|  |  |  | 		Id:                id, | 
					
						
							|  |  |  | 		Object:            "text_completion", | 
					
						
							|  |  |  | 		Created:           r.CreatedAt.Unix(), | 
					
						
							|  |  |  | 		Model:             r.Model, | 
					
						
							|  |  |  | 		SystemFingerprint: "fp_ollama", | 
					
						
							|  |  |  | 		Choices: []CompleteChunkChoice{{ | 
					
						
							|  |  |  | 			Text:  r.Response, | 
					
						
							|  |  |  | 			Index: 0, | 
					
						
							|  |  |  | 			FinishReason: func(reason string) *string { | 
					
						
							|  |  |  | 				if len(reason) > 0 { | 
					
						
							|  |  |  | 					return &reason | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 				return nil | 
					
						
							|  |  |  | 			}(r.DoneReason), | 
					
						
							|  |  |  | 		}}, | 
					
						
							|  |  |  | 		Usage: Usage{ | 
					
						
							|  |  |  | 			PromptTokens:     r.PromptEvalCount, | 
					
						
							|  |  |  | 			CompletionTokens: r.EvalCount, | 
					
						
							|  |  |  | 			TotalTokens:      r.PromptEvalCount + r.EvalCount, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { | 
					
						
							|  |  |  | 	return CompletionChunk{ | 
					
						
							|  |  |  | 		Id:                id, | 
					
						
							|  |  |  | 		Object:            "text_completion", | 
					
						
							|  |  |  | 		Created:           time.Now().Unix(), | 
					
						
							|  |  |  | 		Model:             r.Model, | 
					
						
							|  |  |  | 		SystemFingerprint: "fp_ollama", | 
					
						
							|  |  |  | 		Choices: []CompleteChunkChoice{{ | 
					
						
							|  |  |  | 			Text:  r.Response, | 
					
						
							|  |  |  | 			Index: 0, | 
					
						
							|  |  |  | 			FinishReason: func(reason string) *string { | 
					
						
							|  |  |  | 				if len(reason) > 0 { | 
					
						
							|  |  |  | 					return &reason | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 				return nil | 
					
						
							|  |  |  | 			}(r.DoneReason), | 
					
						
							|  |  |  | 		}}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | func toListCompletion(r api.ListResponse) ListCompletion { | 
					
						
							|  |  |  | 	var data []Model | 
					
						
							|  |  |  | 	for _, m := range r.Models { | 
					
						
							|  |  |  | 		data = append(data, Model{ | 
					
						
							|  |  |  | 			Id:      m.Name, | 
					
						
							|  |  |  | 			Object:  "model", | 
					
						
							|  |  |  | 			Created: m.ModifiedAt.Unix(), | 
					
						
							|  |  |  | 			OwnedBy: model.ParseName(m.Name).Namespace, | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return ListCompletion{ | 
					
						
							|  |  |  | 		Object: "list", | 
					
						
							|  |  |  | 		Data:   data, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { | 
					
						
							|  |  |  | 	if r.Embeddings != nil { | 
					
						
							|  |  |  | 		var data []Embedding | 
					
						
							|  |  |  | 		for i, e := range r.Embeddings { | 
					
						
							|  |  |  | 			data = append(data, Embedding{ | 
					
						
							|  |  |  | 				Object:    "embedding", | 
					
						
							|  |  |  | 				Embedding: e, | 
					
						
							|  |  |  | 				Index:     i, | 
					
						
							|  |  |  | 			}) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		return EmbeddingList{ | 
					
						
							|  |  |  | 			Object: "list", | 
					
						
							|  |  |  | 			Data:   data, | 
					
						
							|  |  |  | 			Model:  model, | 
					
						
							| 
									
										
										
										
											2024-08-02 06:49:37 +08:00
										 |  |  | 			Usage: EmbeddingUsage{ | 
					
						
							|  |  |  | 				PromptTokens: r.PromptEvalCount, | 
					
						
							|  |  |  | 				TotalTokens:  r.PromptEvalCount, | 
					
						
							|  |  |  | 			}, | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return EmbeddingList{} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | func toModel(r api.ShowResponse, m string) Model { | 
					
						
							|  |  |  | 	return Model{ | 
					
						
							|  |  |  | 		Id:      m, | 
					
						
							|  |  |  | 		Object:  "model", | 
					
						
							|  |  |  | 		Created: r.ModifiedAt.Unix(), | 
					
						
							|  |  |  | 		OwnedBy: model.ParseName(m).Namespace, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	var messages []api.Message | 
					
						
							|  |  |  | 	for _, msg := range r.Messages { | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 		switch content := msg.Content.(type) { | 
					
						
							|  |  |  | 		case string: | 
					
						
							|  |  |  | 			messages = append(messages, api.Message{Role: msg.Role, Content: content}) | 
					
						
							|  |  |  | 		case []any: | 
					
						
							|  |  |  | 			for _, c := range content { | 
					
						
							|  |  |  | 				data, ok := c.(map[string]any) | 
					
						
							|  |  |  | 				if !ok { | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 					return nil, errors.New("invalid message format") | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 				} | 
					
						
							|  |  |  | 				switch data["type"] { | 
					
						
							|  |  |  | 				case "text": | 
					
						
							|  |  |  | 					text, ok := data["text"].(string) | 
					
						
							|  |  |  | 					if !ok { | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 						return nil, errors.New("invalid message format") | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 					} | 
					
						
							| 
									
										
										
										
											2024-07-20 02:19:20 +08:00
										 |  |  | 					messages = append(messages, api.Message{Role: msg.Role, Content: text}) | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 				case "image_url": | 
					
						
							|  |  |  | 					var url string | 
					
						
							|  |  |  | 					if urlMap, ok := data["image_url"].(map[string]any); ok { | 
					
						
							|  |  |  | 						if url, ok = urlMap["url"].(string); !ok { | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 							return nil, errors.New("invalid message format") | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 						} | 
					
						
							|  |  |  | 					} else { | 
					
						
							|  |  |  | 						if url, ok = data["image_url"].(string); !ok { | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 							return nil, errors.New("invalid message format") | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 						} | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					types := []string{"jpeg", "jpg", "png"} | 
					
						
							|  |  |  | 					valid := false | 
					
						
							|  |  |  | 					for _, t := range types { | 
					
						
							|  |  |  | 						prefix := "data:image/" + t + ";base64," | 
					
						
							|  |  |  | 						if strings.HasPrefix(url, prefix) { | 
					
						
							|  |  |  | 							url = strings.TrimPrefix(url, prefix) | 
					
						
							|  |  |  | 							valid = true | 
					
						
							|  |  |  | 							break | 
					
						
							|  |  |  | 						} | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					if !valid { | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 						return nil, errors.New("invalid image input") | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 					} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					img, err := base64.StdEncoding.DecodeString(url) | 
					
						
							|  |  |  | 					if err != nil { | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 						return nil, errors.New("invalid message format") | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 					} | 
					
						
							| 
									
										
										
										
											2024-07-20 02:19:20 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}}) | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 				default: | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 					return nil, errors.New("invalid message format") | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		default: | 
					
						
							| 
									
										
										
										
											2024-07-17 11:52:59 +08:00
										 |  |  | 			if msg.ToolCalls == nil { | 
					
						
							|  |  |  | 				return nil, fmt.Errorf("invalid message content type: %T", content) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			toolCalls := make([]api.ToolCall, len(msg.ToolCalls)) | 
					
						
							|  |  |  | 			for i, tc := range msg.ToolCalls { | 
					
						
							|  |  |  | 				toolCalls[i].Function.Name = tc.Function.Name | 
					
						
							|  |  |  | 				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments) | 
					
						
							|  |  |  | 				if err != nil { | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 					return nil, errors.New("invalid tool call arguments") | 
					
						
							| 
									
										
										
										
											2024-07-17 11:52:59 +08:00
										 |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls}) | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	options := make(map[string]interface{}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	switch stop := r.Stop.(type) { | 
					
						
							|  |  |  | 	case string: | 
					
						
							|  |  |  | 		options["stop"] = []string{stop} | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 	case []any: | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 		var stops []string | 
					
						
							|  |  |  | 		for _, s := range stop { | 
					
						
							|  |  |  | 			if str, ok := s.(string); ok { | 
					
						
							|  |  |  | 				stops = append(stops, str) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		options["stop"] = stops | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if r.MaxTokens != nil { | 
					
						
							|  |  |  | 		options["num_predict"] = *r.MaxTokens | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if r.Temperature != nil { | 
					
						
							|  |  |  | 		options["temperature"] = *r.Temperature * 2.0 | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		options["temperature"] = 1.0 | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if r.Seed != nil { | 
					
						
							|  |  |  | 		options["seed"] = *r.Seed | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if r.FrequencyPenalty != nil { | 
					
						
							|  |  |  | 		options["frequency_penalty"] = *r.FrequencyPenalty * 2.0 | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if r.PresencePenalty != nil { | 
					
						
							|  |  |  | 		options["presence_penalty"] = *r.PresencePenalty * 2.0 | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if r.TopP != nil { | 
					
						
							|  |  |  | 		options["top_p"] = *r.TopP | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		options["top_p"] = 1.0 | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var format string | 
					
						
							|  |  |  | 	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" { | 
					
						
							|  |  |  | 		format = "json" | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 	return &api.ChatRequest{ | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 		Model:    r.Model, | 
					
						
							|  |  |  | 		Messages: messages, | 
					
						
							|  |  |  | 		Format:   format, | 
					
						
							|  |  |  | 		Options:  options, | 
					
						
							|  |  |  | 		Stream:   &r.Stream, | 
					
						
							| 
									
										
										
										
											2024-07-17 11:52:59 +08:00
										 |  |  | 		Tools:    r.Tools, | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 	}, nil | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { | 
					
						
							|  |  |  | 	options := make(map[string]any) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	switch stop := r.Stop.(type) { | 
					
						
							|  |  |  | 	case string: | 
					
						
							|  |  |  | 		options["stop"] = []string{stop} | 
					
						
							| 
									
										
										
										
											2024-07-10 05:01:26 +08:00
										 |  |  | 	case []any: | 
					
						
							|  |  |  | 		var stops []string | 
					
						
							|  |  |  | 		for _, s := range stop { | 
					
						
							|  |  |  | 			if str, ok := s.(string); ok { | 
					
						
							|  |  |  | 				stops = append(stops, str) | 
					
						
							|  |  |  | 			} else { | 
					
						
							|  |  |  | 				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s) | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2024-07-10 05:01:26 +08:00
										 |  |  | 		options["stop"] = stops | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if r.MaxTokens != nil { | 
					
						
							|  |  |  | 		options["num_predict"] = *r.MaxTokens | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if r.Temperature != nil { | 
					
						
							|  |  |  | 		options["temperature"] = *r.Temperature * 2.0 | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		options["temperature"] = 1.0 | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if r.Seed != nil { | 
					
						
							|  |  |  | 		options["seed"] = *r.Seed | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	options["frequency_penalty"] = r.FrequencyPenalty * 2.0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	options["presence_penalty"] = r.PresencePenalty * 2.0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if r.TopP != 0.0 { | 
					
						
							|  |  |  | 		options["top_p"] = r.TopP | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		options["top_p"] = 1.0 | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return api.GenerateRequest{ | 
					
						
							|  |  |  | 		Model:   r.Model, | 
					
						
							|  |  |  | 		Prompt:  r.Prompt, | 
					
						
							|  |  |  | 		Options: options, | 
					
						
							|  |  |  | 		Stream:  &r.Stream, | 
					
						
							| 
									
										
										
										
											2024-07-17 11:50:14 +08:00
										 |  |  | 		Suffix:  r.Suffix, | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 	}, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | type BaseWriter struct { | 
					
						
							|  |  |  | 	gin.ResponseWriter | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type ChatWriter struct { | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	stream bool | 
					
						
							|  |  |  | 	id     string | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 	BaseWriter | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | type CompleteWriter struct { | 
					
						
							|  |  |  | 	stream bool | 
					
						
							|  |  |  | 	id     string | 
					
						
							|  |  |  | 	BaseWriter | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | type ListWriter struct { | 
					
						
							|  |  |  | 	BaseWriter | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type RetrieveWriter struct { | 
					
						
							|  |  |  | 	BaseWriter | 
					
						
							|  |  |  | 	model string | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | type EmbedWriter struct { | 
					
						
							|  |  |  | 	BaseWriter | 
					
						
							|  |  |  | 	model string | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | func (w *BaseWriter) writeError(code int, data []byte) (int, error) { | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	var serr api.StatusError | 
					
						
							|  |  |  | 	err := json.Unmarshal(data, &serr) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	w.ResponseWriter.Header().Set("Content-Type", "application/json") | 
					
						
							|  |  |  | 	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error())) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return len(data), nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | func (w *ChatWriter) writeResponse(data []byte) (int, error) { | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	var chatResponse api.ChatResponse | 
					
						
							|  |  |  | 	err := json.Unmarshal(data, &chatResponse) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// chat chunk
 | 
					
						
							|  |  |  | 	if w.stream { | 
					
						
							|  |  |  | 		d, err := json.Marshal(toChunk(w.id, chatResponse)) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return 0, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") | 
					
						
							|  |  |  | 		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return 0, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if chatResponse.Done { | 
					
						
							|  |  |  | 			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return 0, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		return len(data), nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// chat completion
 | 
					
						
							|  |  |  | 	w.ResponseWriter.Header().Set("Content-Type", "application/json") | 
					
						
							|  |  |  | 	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse)) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return len(data), nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | func (w *ChatWriter) Write(data []byte) (int, error) { | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	code := w.ResponseWriter.Status() | 
					
						
							|  |  |  | 	if code != http.StatusOK { | 
					
						
							|  |  |  | 		return w.writeError(code, data) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return w.writeResponse(data) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | func (w *CompleteWriter) writeResponse(data []byte) (int, error) { | 
					
						
							|  |  |  | 	var generateResponse api.GenerateResponse | 
					
						
							|  |  |  | 	err := json.Unmarshal(data, &generateResponse) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// completion chunk
 | 
					
						
							|  |  |  | 	if w.stream { | 
					
						
							|  |  |  | 		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse)) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return 0, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") | 
					
						
							|  |  |  | 		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return 0, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if generateResponse.Done { | 
					
						
							|  |  |  | 			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return 0, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		return len(data), nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// completion
 | 
					
						
							|  |  |  | 	w.ResponseWriter.Header().Set("Content-Type", "application/json") | 
					
						
							|  |  |  | 	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse)) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return len(data), nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (w *CompleteWriter) Write(data []byte) (int, error) { | 
					
						
							|  |  |  | 	code := w.ResponseWriter.Status() | 
					
						
							|  |  |  | 	if code != http.StatusOK { | 
					
						
							|  |  |  | 		return w.writeError(code, data) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return w.writeResponse(data) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | func (w *ListWriter) writeResponse(data []byte) (int, error) { | 
					
						
							|  |  |  | 	var listResponse api.ListResponse | 
					
						
							|  |  |  | 	err := json.Unmarshal(data, &listResponse) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	w.ResponseWriter.Header().Set("Content-Type", "application/json") | 
					
						
							|  |  |  | 	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse)) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return len(data), nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (w *ListWriter) Write(data []byte) (int, error) { | 
					
						
							|  |  |  | 	code := w.ResponseWriter.Status() | 
					
						
							|  |  |  | 	if code != http.StatusOK { | 
					
						
							|  |  |  | 		return w.writeError(code, data) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return w.writeResponse(data) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (w *RetrieveWriter) writeResponse(data []byte) (int, error) { | 
					
						
							|  |  |  | 	var showResponse api.ShowResponse | 
					
						
							|  |  |  | 	err := json.Unmarshal(data, &showResponse) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// retrieve completion
 | 
					
						
							|  |  |  | 	w.ResponseWriter.Header().Set("Content-Type", "application/json") | 
					
						
							|  |  |  | 	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model)) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return len(data), nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (w *RetrieveWriter) Write(data []byte) (int, error) { | 
					
						
							|  |  |  | 	code := w.ResponseWriter.Status() | 
					
						
							|  |  |  | 	if code != http.StatusOK { | 
					
						
							|  |  |  | 		return w.writeError(code, data) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return w.writeResponse(data) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | func (w *EmbedWriter) writeResponse(data []byte) (int, error) { | 
					
						
							|  |  |  | 	var embedResponse api.EmbedResponse | 
					
						
							|  |  |  | 	err := json.Unmarshal(data, &embedResponse) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	w.ResponseWriter.Header().Set("Content-Type", "application/json") | 
					
						
							|  |  |  | 	err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse)) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return 0, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return len(data), nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (w *EmbedWriter) Write(data []byte) (int, error) { | 
					
						
							|  |  |  | 	code := w.ResponseWriter.Status() | 
					
						
							|  |  |  | 	if code != http.StatusOK { | 
					
						
							|  |  |  | 		return w.writeError(code, data) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return w.writeResponse(data) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | func ListMiddleware() gin.HandlerFunc { | 
					
						
							|  |  |  | 	return func(c *gin.Context) { | 
					
						
							|  |  |  | 		w := &ListWriter{ | 
					
						
							|  |  |  | 			BaseWriter: BaseWriter{ResponseWriter: c.Writer}, | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Writer = w | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Next() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func RetrieveMiddleware() gin.HandlerFunc { | 
					
						
							|  |  |  | 	return func(c *gin.Context) { | 
					
						
							|  |  |  | 		var b bytes.Buffer | 
					
						
							|  |  |  | 		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Request.Body = io.NopCloser(&b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// response writer
 | 
					
						
							|  |  |  | 		w := &RetrieveWriter{ | 
					
						
							|  |  |  | 			BaseWriter: BaseWriter{ResponseWriter: c.Writer}, | 
					
						
							|  |  |  | 			model:      c.Param("model"), | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Writer = w | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Next() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | func CompletionsMiddleware() gin.HandlerFunc { | 
					
						
							|  |  |  | 	return func(c *gin.Context) { | 
					
						
							|  |  |  | 		var req CompletionRequest | 
					
						
							|  |  |  | 		err := c.ShouldBindJSON(&req) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		var b bytes.Buffer | 
					
						
							|  |  |  | 		genReq, err := fromCompleteRequest(req) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if err := json.NewEncoder(&b).Encode(genReq); err != nil { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Request.Body = io.NopCloser(&b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		w := &CompleteWriter{ | 
					
						
							|  |  |  | 			BaseWriter: BaseWriter{ResponseWriter: c.Writer}, | 
					
						
							|  |  |  | 			stream:     req.Stream, | 
					
						
							|  |  |  | 			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)), | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | 		c.Writer = w | 
					
						
							|  |  |  | 		c.Next() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func EmbeddingsMiddleware() gin.HandlerFunc { | 
					
						
							|  |  |  | 	return func(c *gin.Context) { | 
					
						
							|  |  |  | 		var req EmbedRequest | 
					
						
							|  |  |  | 		err := c.ShouldBindJSON(&req) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if req.Input == "" { | 
					
						
							|  |  |  | 			req.Input = []string{""} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if req.Input == nil { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if v, ok := req.Input.([]any); ok && len(v) == 0 { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		var b bytes.Buffer | 
					
						
							|  |  |  | 		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Request.Body = io.NopCloser(&b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		w := &EmbedWriter{ | 
					
						
							|  |  |  | 			BaseWriter: BaseWriter{ResponseWriter: c.Writer}, | 
					
						
							|  |  |  | 			model:      req.Model, | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 		c.Writer = w | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Next() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | func ChatMiddleware() gin.HandlerFunc { | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 	return func(c *gin.Context) { | 
					
						
							|  |  |  | 		var req ChatCompletionRequest | 
					
						
							|  |  |  | 		err := c.ShouldBindJSON(&req) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if len(req.Messages) == 0 { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'")) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		var b bytes.Buffer | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 		chatReq, err := fromChatRequest(req) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 			return | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if err := json.NewEncoder(&b).Encode(chatReq); err != nil { | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Request.Body = io.NopCloser(&b) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 		w := &ChatWriter{ | 
					
						
							|  |  |  | 			BaseWriter: BaseWriter{ResponseWriter: c.Writer}, | 
					
						
							|  |  |  | 			stream:     req.Stream, | 
					
						
							|  |  |  | 			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), | 
					
						
							| 
									
										
										
										
											2024-02-08 06:24:29 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Writer = w | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		c.Next() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } |