// Package openai provides middleware for partial compatibility with the OpenAI REST API.
package openai

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"math/rand"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)
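
// The middleware below is typically mounted on a gin router in front of the
// native handlers. A minimal sketch, assuming hypothetical handler names
// (chatHandler, generateHandler, and so on) that are not part of this file:
//
//	r := gin.Default()
//	r.POST("/v1/chat/completions", ChatMiddleware(), chatHandler)
//	r.POST("/v1/completions", CompletionsMiddleware(), generateHandler)
//	r.POST("/v1/embeddings", EmbeddingsMiddleware(), embedHandler)
//	r.GET("/v1/models", ListMiddleware(), listHandler)
//	r.GET("/v1/models/:model", RetrieveMiddleware(), showHandler)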

type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

type ErrorResponse struct {
	Error Error `json:"error"`
}

type Message struct {
	Role      string     `json:"role"`
	Content   any        `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type ResponseFormat struct {
	Type string `json:"type"`
}

type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}

type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"`
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
}

type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
	Model            string   `json:"model"`
	Prompt           string   `json:"prompt"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	MaxTokens        *int     `json:"max_tokens"`
	PresencePenalty  float32  `json:"presence_penalty"`
	Seed             *int     `json:"seed"`
	Stop             any      `json:"stop"`
	Stream           bool     `json:"stream"`
	Temperature      *float32 `json:"temperature"`
	TopP             float32  `json:"top_p"`
	Suffix           string   `json:"suffix"`
}

type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
}

type ToolCall struct {
	ID       string `json:"id"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

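// NewError wraps message in an OpenAI-style error envelope, mapping the HTTP
// status code to the closest OpenAI error type string.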
func NewError(code int, message string) ErrorResponse {
	var etype string
	switch code {
	case http.StatusBadRequest:
		etype = "invalid_request_error"
	case http.StatusNotFound:
		etype = "not_found_error"
	default:
		etype = "api_error"
	}

	return ErrorResponse{Error{Type: etype, Message: message}}
}

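// toolCallId generates a pseudo-random identifier of the form "call_xxxxxxxx",
// mirroring the id format the OpenAI API uses for tool calls.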
func toolCallId() string {
	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, 8)
	for i := range b {
		b[i] = letterBytes[rand.Intn(len(letterBytes))]
	}
	return "call_" + strings.ToLower(string(b))
}

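// toChatCompletion converts an api.ChatResponse into an OpenAI chat.completion
// object, serializing tool call arguments to JSON strings and mapping token
// counts to usage fields.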
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
	for i, tc := range r.Message.ToolCalls {
		toolCalls[i].ID = toolCallId()
		toolCalls[i].Type = "function"
		toolCalls[i].Function.Name = tc.Function.Name

		args, err := json.Marshal(tc.Function.Arguments)
		if err != nil {
			slog.Error("could not marshal function arguments to json", "error", err)
			continue
		}

		toolCalls[i].Function.Arguments = string(args)
	}

	return ChatCompletion{
		Id:                id,
		Object:            "chat.completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []Choice{{
			Index:   0,
			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
			FinishReason: func(reason string) *string {
				if len(toolCalls) > 0 {
					reason = "tool_calls"
				}
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: Usage{
			PromptTokens:     r.PromptEvalCount,
			CompletionTokens: r.EvalCount,
			TotalTokens:      r.PromptEvalCount + r.EvalCount,
		},
	}
}

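// toChunk converts a streaming api.ChatResponse into an OpenAI
// chat.completion.chunk object.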
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
	return ChatCompletionChunk{
		Id:                id,
		Object:            "chat.completion.chunk",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []ChunkChoice{{
			Index: 0,
			Delta: Message{Role: "assistant", Content: r.Message.Content},
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
	}
}

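// toCompletion converts an api.GenerateResponse into an OpenAI text_completion
// object.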
func toCompletion(id string, r api.GenerateResponse) Completion {
	return Completion{
		Id:                id,
		Object:            "text_completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: Usage{
			PromptTokens:     r.PromptEvalCount,
			CompletionTokens: r.EvalCount,
			TotalTokens:      r.PromptEvalCount + r.EvalCount,
		},
	}
}

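// toCompleteChunk converts a streaming api.GenerateResponse into a
// text_completion chunk.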
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
	return CompletionChunk{
		Id:                id,
		Object:            "text_completion",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
	}
}

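// toListCompletion converts an api.ListResponse into an OpenAI model list,
// using each model's namespace as its owner.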
func toListCompletion(r api.ListResponse) ListCompletion {
	var data []Model
	for _, m := range r.Models {
		data = append(data, Model{
			Id:      m.Name,
			Object:  "model",
			Created: m.ModifiedAt.Unix(),
			OwnedBy: model.ParseName(m.Name).Namespace,
		})
	}

	return ListCompletion{
		Object: "list",
		Data:   data,
	}
}

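// toEmbeddingList converts an api.EmbedResponse into an OpenAI embedding list;
// the prompt token count is reported as both prompt and total usage.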
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
	if r.Embeddings != nil {
		var data []Embedding
		for i, e := range r.Embeddings {
			data = append(data, Embedding{
				Object:    "embedding",
				Embedding: e,
				Index:     i,
			})
		}

		return EmbeddingList{
			Object: "list",
			Data:   data,
			Model:  model,
			Usage: EmbeddingUsage{
				PromptTokens: r.PromptEvalCount,
				TotalTokens:  r.PromptEvalCount,
			},
		}
	}

	return EmbeddingList{}
}

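// toModel converts an api.ShowResponse into an OpenAI model object.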
func toModel(r api.ShowResponse, m string) Model {
	return Model{
		Id:      m,
		Object:  "model",
		Created: r.ModifiedAt.Unix(),
		OwnedBy: model.ParseName(m).Namespace,
	}
}

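// fromChatRequest translates an OpenAI chat completion request into an
// api.ChatRequest: multi-part message content is flattened into separate
// messages, base64 image data is decoded, and sampling parameters are mapped
// to native options (temperature and penalty values are scaled by 2).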
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			messages = append(messages, api.Message{Role: msg.Role, Content: content})
		case []any:
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}

					types := []string{"jpeg", "jpg", "png"}
					valid := false
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}

					if !valid {
						return nil, errors.New("invalid image input")
					}

					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, errors.New("invalid message format")
					}

					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
		default:
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}

			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
			for i, tc := range msg.ToolCalls {
				toolCalls[i].Function.Name = tc.Function.Name
				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
				if err != nil {
					return nil, errors.New("invalid tool call arguments")
				}
			}
			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
		}
	}

	options := make(map[string]interface{})

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature * 2.0
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty * 2.0
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format string
	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
		format = "json"
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
		Tools:    r.Tools,
	}, nil
}

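// fromCompleteRequest translates an OpenAI completion request into an
// api.GenerateRequest, mapping sampling parameters to native options.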
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			} else {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature * 2.0
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	options["frequency_penalty"] = r.FrequencyPenalty * 2.0

	options["presence_penalty"] = r.PresencePenalty * 2.0

	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	} else {
		options["top_p"] = 1.0
	}

	return api.GenerateRequest{
		Model:   r.Model,
		Prompt:  r.Prompt,
		Options: options,
		Stream:  &r.Stream,
		Suffix:  r.Suffix,
	}, nil
}

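// The writers below wrap gin's ResponseWriter so that JSON produced by the
// native Ollama handlers is rewritten into OpenAI-compatible responses (or
// SSE chunks when streaming) before it reaches the client.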
type BaseWriter struct {
	gin.ResponseWriter
}

type ChatWriter struct {
	stream bool
	id     string
	BaseWriter
}

type CompleteWriter struct {
	stream bool
	id     string
	BaseWriter
}

type ListWriter struct {
	BaseWriter
}

type RetrieveWriter struct {
	BaseWriter
	model string
}

type EmbedWriter struct {
	BaseWriter
	model string
}

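// writeError decodes an upstream api.StatusError payload and re-encodes it as
// an OpenAI-style error response.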
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
	var serr api.StatusError
	err := json.Unmarshal(data, &serr)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		d, err := json.Marshal(toChunk(w.id, chatResponse))
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if chatResponse.Done {
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ChatWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if generateResponse.Done {
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *CompleteWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *ListWriter) writeResponse(data []byte) (int, error) {
	var listResponse api.ListResponse
	err := json.Unmarshal(data, &listResponse)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ListWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
	var showResponse api.ShowResponse
	err := json.Unmarshal(data, &showResponse)
	if err != nil {
		return 0, err
	}

	// retrieve completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *RetrieveWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
	var embedResponse api.EmbedResponse
	err := json.Unmarshal(data, &embedResponse)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *EmbedWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
	}

	return w.writeResponse(data)
}

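// ListMiddleware rewrites responses from the native model list endpoint into
// an OpenAI model list.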
func ListMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		w := &ListWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
		}

		c.Writer = w

		c.Next()
	}
}

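// RetrieveMiddleware rewrites the request body into an api.ShowRequest for the
// named model and converts the response into an OpenAI model object.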
func RetrieveMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		// response writer
		w := &RetrieveWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      c.Param("model"),
		}

		c.Writer = w

		c.Next()
	}
}

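// CompletionsMiddleware converts an OpenAI completion request into a native
// generate request and rewrites the response, streamed or not, into OpenAI
// format.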
func CompletionsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req CompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		var b bytes.Buffer
		genReq, err := fromCompleteRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &CompleteWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
		}

		c.Writer = w
		c.Next()
	}
}

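// EmbeddingsMiddleware validates the embedding input, converts the request
// into an api.EmbedRequest, and rewrites the response into an OpenAI embedding
// list.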
func EmbeddingsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req EmbedRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if req.Input == "" {
			req.Input = []string{""}
		}

		if req.Input == nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		if v, ok := req.Input.([]any); ok && len(v) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &EmbedWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      req.Model,
		}

		c.Writer = w

		c.Next()
	}
}

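// ChatMiddleware converts an OpenAI chat completion request into a native chat
// request and rewrites the response, streamed or not, into OpenAI format.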
func ChatMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req ChatCompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if len(req.Messages) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
			return
		}

		var b bytes.Buffer

		chatReq, err := fromChatRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &ChatWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream:     req.Stream,
			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
		}

		c.Writer = w

		c.Next()
	}
}