mirror of https://github.com/ollama/ollama.git
openai: refactor to split compat layer and middleware
This makes the core openai compat layer independent of the middleware that adapts it to our particular gin routes
This commit is contained in:
parent
e4340667e3
commit
2c2f4deaa9
|
@ -0,0 +1,424 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BaseWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatWriter struct {
|
||||||
|
stream bool
|
||||||
|
streamOptions *openai.StreamOptions
|
||||||
|
id string
|
||||||
|
toolCallSent bool
|
||||||
|
BaseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompleteWriter struct {
|
||||||
|
stream bool
|
||||||
|
streamOptions *openai.StreamOptions
|
||||||
|
id string
|
||||||
|
BaseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListWriter struct {
|
||||||
|
BaseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type RetrieveWriter struct {
|
||||||
|
BaseWriter
|
||||||
|
model string
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbedWriter struct {
|
||||||
|
BaseWriter
|
||||||
|
model string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
||||||
|
var serr api.StatusError
|
||||||
|
err := json.Unmarshal(data, &serr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(openai.NewError(http.StatusInternalServerError, serr.Error()))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var chatResponse api.ChatResponse
|
||||||
|
err := json.Unmarshal(data, &chatResponse)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// chat chunk
|
||||||
|
if w.stream {
|
||||||
|
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
|
||||||
|
d, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
||||||
|
w.toolCallSent = true
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatResponse.Done {
|
||||||
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||||
|
u := openai.ToUsage(chatResponse)
|
||||||
|
c.Usage = &u
|
||||||
|
c.Choices = []openai.ChunkChoice{}
|
||||||
|
d, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chat completion
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToChatCompletion(w.id, chatResponse))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ChatWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var generateResponse api.GenerateResponse
|
||||||
|
err := json.Unmarshal(data, &generateResponse)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// completion chunk
|
||||||
|
if w.stream {
|
||||||
|
c := openai.ToCompleteChunk(w.id, generateResponse)
|
||||||
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||||
|
c.Usage = &openai.Usage{}
|
||||||
|
}
|
||||||
|
d, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if generateResponse.Done {
|
||||||
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||||
|
u := openai.ToUsageGenerate(generateResponse)
|
||||||
|
c.Usage = &u
|
||||||
|
c.Choices = []openai.CompleteChunkChoice{}
|
||||||
|
d, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// completion
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToCompletion(w.id, generateResponse))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var listResponse api.ListResponse
|
||||||
|
err := json.Unmarshal(data, &listResponse)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToListCompletion(listResponse))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ListWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var showResponse api.ShowResponse
|
||||||
|
err := json.Unmarshal(data, &showResponse)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// retrieve completion
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToModel(showResponse, w.model))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var embedResponse api.EmbedResponse
|
||||||
|
err := json.Unmarshal(data, &embedResponse)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
w := &ListWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RetrieveMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
w := &RetrieveWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
model: c.Param("model"),
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CompletionsMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var req openai.CompletionRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
genReq, err := openai.FromCompleteRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
w := &CompleteWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
stream: req.Stream,
|
||||||
|
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||||
|
streamOptions: req.StreamOptions,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var req openai.EmbedRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Input == "" {
|
||||||
|
req.Input = []string{""}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Input == nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
w := &EmbedWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
model: req.Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ChatMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var req openai.ChatCompletionRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.Messages) == 0 {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
|
||||||
|
chatReq, err := openai.FromChatRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
w := &ChatWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
stream: req.Stream,
|
||||||
|
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||||
|
streamOptions: req.StreamOptions,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,928 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
prefix = `data:image/jpeg;base64,`
|
||||||
|
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
False = false
|
||||||
|
True = true
|
||||||
|
)
|
||||||
|
|
||||||
|
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
err := json.Unmarshal(bodyBytes, capturedRequest)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
req api.ChatRequest
|
||||||
|
err openai.ErrorResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.ChatRequest
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
name: "chat handler",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Hello",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with options",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
],
|
||||||
|
"stream": true,
|
||||||
|
"max_tokens": 999,
|
||||||
|
"seed": 123,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"temperature": 3.0,
|
||||||
|
"frequency_penalty": 4.0,
|
||||||
|
"presence_penalty": 5.0,
|
||||||
|
"top_p": 6.0,
|
||||||
|
"response_format": {"type": "json_object"}
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Hello",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
||||||
|
"seed": 123.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
"temperature": 3.0,
|
||||||
|
"frequency_penalty": 4.0,
|
||||||
|
"presence_penalty": 5.0,
|
||||||
|
"top_p": 6.0,
|
||||||
|
},
|
||||||
|
Format: json.RawMessage(`"json"`),
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with streaming usage",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
],
|
||||||
|
"stream": true,
|
||||||
|
"stream_options": {"include_usage": true},
|
||||||
|
"max_tokens": 999,
|
||||||
|
"seed": 123,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"temperature": 3.0,
|
||||||
|
"frequency_penalty": 4.0,
|
||||||
|
"presence_penalty": 5.0,
|
||||||
|
"top_p": 6.0,
|
||||||
|
"response_format": {"type": "json_object"}
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Hello",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
||||||
|
"seed": 123.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
"temperature": 3.0,
|
||||||
|
"frequency_penalty": 4.0,
|
||||||
|
"presence_penalty": 5.0,
|
||||||
|
"top_p": 6.0,
|
||||||
|
},
|
||||||
|
Format: json.RawMessage(`"json"`),
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with image content",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Hello"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "` + prefix + image + `"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Images: []api.ImageData{
|
||||||
|
func() []byte {
|
||||||
|
img, _ := base64.StdEncoding.DecodeString(image)
|
||||||
|
return img
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with tools",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||||
|
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What's the weather like in Paris Today?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"location": "Paris, France",
|
||||||
|
"format": "celsius",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with tools and content",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||||
|
{"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What's the weather like in Paris Today?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Let's see what the weather is like in Paris",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"location": "Paris, France",
|
||||||
|
"format": "celsius",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with tools and empty content",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||||
|
{"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What's the weather like in Paris Today?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"location": "Paris, France",
|
||||||
|
"format": "celsius",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with tools and thinking content",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||||
|
{"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What's the weather like in Paris Today?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Thinking: "Let's see what the weather is like in Paris",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"location": "Paris, France",
|
||||||
|
"format": "celsius",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool response with call ID",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||||
|
{"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
|
||||||
|
{"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What's the weather like in Paris Today?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"location": "Paris, France",
|
||||||
|
"format": "celsius",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "The weather in Paris is 20 degrees Celsius",
|
||||||
|
ToolName: "get_current_weather",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool response with name",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||||
|
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
|
||||||
|
{"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What's the weather like in Paris Today?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"location": "Paris, France",
|
||||||
|
"format": "celsius",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "The weather in Paris is 20 degrees Celsius",
|
||||||
|
ToolName: "get_current_weather",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with streaming tools",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris?"}
|
||||||
|
],
|
||||||
|
"stream": true,
|
||||||
|
"tools": [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"required": ["location"],
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What's the weather like in Paris?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather",
|
||||||
|
Parameters: struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Defs any `json:"$defs,omitempty"`
|
||||||
|
Items any `json:"items,omitempty"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Properties map[string]api.ToolProperty `json:"properties"`
|
||||||
|
}{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"location"},
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"location": {
|
||||||
|
Type: api.PropertyType{"string"},
|
||||||
|
Description: "The city and state",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
Type: api.PropertyType{"string"},
|
||||||
|
Enum: []any{"celsius", "fahrenheit"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler error forwarding",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": 2}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
err: openai.ErrorResponse{
|
||||||
|
Error: openai.Error{
|
||||||
|
Message: "invalid message content type: float64",
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
defer func() { capturedRequest = nil }()
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
var errResp openai.ErrorResponse
|
||||||
|
if resp.Code != http.StatusOK {
|
||||||
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||||
|
t.Fatalf("requests did not match: %+v", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||||
|
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompletionsMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
req api.GenerateRequest
|
||||||
|
err openai.ErrorResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.GenerateRequest
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
name: "completions handler",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Hello",
|
||||||
|
"temperature": 0.8,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"suffix": "suffix"
|
||||||
|
}`,
|
||||||
|
req: api.GenerateRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Options: map[string]any{
|
||||||
|
"frequency_penalty": 0.0,
|
||||||
|
"presence_penalty": 0.0,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
},
|
||||||
|
Suffix: "suffix",
|
||||||
|
Stream: &False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completions handler stream",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Hello",
|
||||||
|
"stream": true,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"suffix": "suffix"
|
||||||
|
}`,
|
||||||
|
req: api.GenerateRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Options: map[string]any{
|
||||||
|
"frequency_penalty": 0.0,
|
||||||
|
"presence_penalty": 0.0,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
},
|
||||||
|
Suffix: "suffix",
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completions handler stream with usage",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Hello",
|
||||||
|
"stream": true,
|
||||||
|
"stream_options": {"include_usage": true},
|
||||||
|
"temperature": 0.8,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"suffix": "suffix"
|
||||||
|
}`,
|
||||||
|
req: api.GenerateRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Options: map[string]any{
|
||||||
|
"frequency_penalty": 0.0,
|
||||||
|
"presence_penalty": 0.0,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
},
|
||||||
|
Suffix: "suffix",
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completions handler error forwarding",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Hello",
|
||||||
|
"temperature": null,
|
||||||
|
"stop": [1, 2],
|
||||||
|
"suffix": "suffix"
|
||||||
|
}`,
|
||||||
|
err: openai.ErrorResponse{
|
||||||
|
Error: openai.Error{
|
||||||
|
Message: "invalid type for 'stop' field: float64",
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
var errResp openai.ErrorResponse
|
||||||
|
if resp.Code != http.StatusOK {
|
||||||
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||||
|
t.Fatal("requests did not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tc.err, errResp) {
|
||||||
|
t.Fatal("errors did not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
req api.EmbedRequest
|
||||||
|
err openai.ErrorResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.EmbedRequest
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
name: "embed handler single input",
|
||||||
|
body: `{
|
||||||
|
"input": "Hello",
|
||||||
|
"model": "test-model"
|
||||||
|
}`,
|
||||||
|
req: api.EmbedRequest{
|
||||||
|
Input: "Hello",
|
||||||
|
Model: "test-model",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "embed handler batch input",
|
||||||
|
body: `{
|
||||||
|
"input": ["Hello", "World"],
|
||||||
|
"model": "test-model"
|
||||||
|
}`,
|
||||||
|
req: api.EmbedRequest{
|
||||||
|
Input: []any{"Hello", "World"},
|
||||||
|
Model: "test-model",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "embed handler error forwarding",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model"
|
||||||
|
}`,
|
||||||
|
err: openai.ErrorResponse{
|
||||||
|
Error: openai.Error{
|
||||||
|
Message: "invalid input",
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
var errResp openai.ErrorResponse
|
||||||
|
if resp.Code != http.StatusOK {
|
||||||
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||||
|
t.Fatal("requests did not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tc.err, errResp) {
|
||||||
|
t.Fatal("errors did not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
endpoint func(c *gin.Context)
|
||||||
|
resp string
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
name: "list handler",
|
||||||
|
endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, api.ListResponse{
|
||||||
|
Models: []api.ListModelResponse{
|
||||||
|
{
|
||||||
|
Name: "test-model",
|
||||||
|
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
resp: `{
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "test-model",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1686935002,
|
||||||
|
"owned_by": "library"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "list handler empty output",
|
||||||
|
endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, api.ListResponse{})
|
||||||
|
},
|
||||||
|
resp: `{
|
||||||
|
"object": "list",
|
||||||
|
"data": null
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ListMiddleware())
|
||||||
|
router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
var expected, actual map[string]any
|
||||||
|
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(expected, actual) {
|
||||||
|
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrieveMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
endpoint func(c *gin.Context)
|
||||||
|
resp string
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
name: "retrieve handler",
|
||||||
|
endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, api.ShowResponse{
|
||||||
|
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
||||||
|
})
|
||||||
|
},
|
||||||
|
resp: `{
|
||||||
|
"id":"test-model",
|
||||||
|
"object":"model",
|
||||||
|
"created":1686935002,
|
||||||
|
"owned_by":"library"}
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve handler error forwarding",
|
||||||
|
endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
|
||||||
|
},
|
||||||
|
resp: `{
|
||||||
|
"error": {
|
||||||
|
"code": null,
|
||||||
|
"message": "model not found",
|
||||||
|
"param": null,
|
||||||
|
"type": "api_error"
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(RetrieveMiddleware())
|
||||||
|
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
var expected, actual map[string]any
|
||||||
|
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(expected, actual) {
|
||||||
|
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
465
openai/openai.go
465
openai/openai.go
|
@ -1,21 +1,17 @@
|
||||||
// openai package provides middleware for partial compatibility with the OpenAI REST API
|
// openai package provides core transformation logic for partial compatibility with the OpenAI REST API
|
||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
@ -220,11 +216,12 @@ func NewError(code int, message string) ErrorResponse {
|
||||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toUsage(r api.ChatResponse) Usage {
|
// ToUsage converts an api.ChatResponse to Usage
|
||||||
|
func ToUsage(r api.ChatResponse) Usage {
|
||||||
return Usage{
|
return Usage{
|
||||||
PromptTokens: r.PromptEvalCount,
|
PromptTokens: r.Metrics.PromptEvalCount,
|
||||||
CompletionTokens: r.EvalCount,
|
CompletionTokens: r.Metrics.EvalCount,
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,7 +253,8 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
|
||||||
return toolCalls
|
return toolCalls
|
||||||
}
|
}
|
||||||
|
|
||||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
// ToChatCompletion converts an api.ChatResponse to ChatCompletion
|
||||||
|
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||||
return ChatCompletion{
|
return ChatCompletion{
|
||||||
Id: id,
|
Id: id,
|
||||||
|
@ -276,12 +274,13 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}}, Usage: toUsage(r),
|
}}, Usage: ToUsage(r),
|
||||||
DebugInfo: r.DebugInfo,
|
DebugInfo: r.DebugInfo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
// ToChunk converts an api.ChatResponse to ChatCompletionChunk
|
||||||
|
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
||||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||||
return ChatCompletionChunk{
|
return ChatCompletionChunk{
|
||||||
Id: id,
|
Id: id,
|
||||||
|
@ -305,15 +304,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toUsageGenerate(r api.GenerateResponse) Usage {
|
// ToUsageGenerate converts an api.GenerateResponse to Usage
|
||||||
|
func ToUsageGenerate(r api.GenerateResponse) Usage {
|
||||||
return Usage{
|
return Usage{
|
||||||
PromptTokens: r.PromptEvalCount,
|
PromptTokens: r.Metrics.PromptEvalCount,
|
||||||
CompletionTokens: r.EvalCount,
|
CompletionTokens: r.Metrics.EvalCount,
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toCompletion(id string, r api.GenerateResponse) Completion {
|
// ToCompletion converts an api.GenerateResponse to Completion
|
||||||
|
func ToCompletion(id string, r api.GenerateResponse) Completion {
|
||||||
return Completion{
|
return Completion{
|
||||||
Id: id,
|
Id: id,
|
||||||
Object: "text_completion",
|
Object: "text_completion",
|
||||||
|
@ -330,11 +331,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||||
return nil
|
return nil
|
||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: toUsageGenerate(r),
|
Usage: ToUsageGenerate(r),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk
|
||||||
|
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||||
return CompletionChunk{
|
return CompletionChunk{
|
||||||
Id: id,
|
Id: id,
|
||||||
Object: "text_completion",
|
Object: "text_completion",
|
||||||
|
@ -354,7 +356,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toListCompletion(r api.ListResponse) ListCompletion {
|
// ToListCompletion converts an api.ListResponse to ListCompletion
|
||||||
|
func ToListCompletion(r api.ListResponse) ListCompletion {
|
||||||
var data []Model
|
var data []Model
|
||||||
for _, m := range r.Models {
|
for _, m := range r.Models {
|
||||||
data = append(data, Model{
|
data = append(data, Model{
|
||||||
|
@ -371,7 +374,8 @@ func toListCompletion(r api.ListResponse) ListCompletion {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
|
||||||
|
func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||||
if r.Embeddings != nil {
|
if r.Embeddings != nil {
|
||||||
var data []Embedding
|
var data []Embedding
|
||||||
for i, e := range r.Embeddings {
|
for i, e := range r.Embeddings {
|
||||||
|
@ -396,7 +400,8 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||||
return EmbeddingList{}
|
return EmbeddingList{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toModel(r api.ShowResponse, m string) Model {
|
// ToModel converts an api.ShowResponse to Model
|
||||||
|
func ToModel(r api.ShowResponse, m string) Model {
|
||||||
return Model{
|
return Model{
|
||||||
Id: m,
|
Id: m,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
|
@ -405,7 +410,8 @@ func toModel(r api.ShowResponse, m string) Model {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
|
||||||
|
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
for _, msg := range r.Messages {
|
for _, msg := range r.Messages {
|
||||||
toolName := ""
|
toolName := ""
|
||||||
|
@ -609,7 +615,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
|
||||||
return apiToolCalls, nil
|
return apiToolCalls, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
|
||||||
|
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||||
options := make(map[string]any)
|
options := make(map[string]any)
|
||||||
|
|
||||||
switch stop := r.Stop.(type) {
|
switch stop := r.Stop.(type) {
|
||||||
|
@ -660,413 +667,3 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||||
DebugRenderOnly: r.DebugRenderOnly,
|
DebugRenderOnly: r.DebugRenderOnly,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type BaseWriter struct {
|
|
||||||
gin.ResponseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatWriter struct {
|
|
||||||
stream bool
|
|
||||||
streamOptions *StreamOptions
|
|
||||||
id string
|
|
||||||
toolCallSent bool
|
|
||||||
BaseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
type CompleteWriter struct {
|
|
||||||
stream bool
|
|
||||||
streamOptions *StreamOptions
|
|
||||||
id string
|
|
||||||
BaseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
type ListWriter struct {
|
|
||||||
BaseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
type RetrieveWriter struct {
|
|
||||||
BaseWriter
|
|
||||||
model string
|
|
||||||
}
|
|
||||||
|
|
||||||
type EmbedWriter struct {
|
|
||||||
BaseWriter
|
|
||||||
model string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
|
||||||
var serr api.StatusError
|
|
||||||
err := json.Unmarshal(data, &serr)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
||||||
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|
||||||
var chatResponse api.ChatResponse
|
|
||||||
err := json.Unmarshal(data, &chatResponse)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// chat chunk
|
|
||||||
if w.stream {
|
|
||||||
c := toChunk(w.id, chatResponse, w.toolCallSent)
|
|
||||||
d, err := json.Marshal(c)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
|
||||||
w.toolCallSent = true
|
|
||||||
}
|
|
||||||
|
|
||||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if chatResponse.Done {
|
|
||||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
|
||||||
u := toUsage(chatResponse)
|
|
||||||
c.Usage = &u
|
|
||||||
c.Choices = []ChunkChoice{}
|
|
||||||
d, err := json.Marshal(c)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// chat completion
|
|
||||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
||||||
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *ChatWriter) Write(data []byte) (int, error) {
|
|
||||||
code := w.ResponseWriter.Status()
|
|
||||||
if code != http.StatusOK {
|
|
||||||
return w.writeError(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.writeResponse(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
|
||||||
var generateResponse api.GenerateResponse
|
|
||||||
err := json.Unmarshal(data, &generateResponse)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// completion chunk
|
|
||||||
if w.stream {
|
|
||||||
c := toCompleteChunk(w.id, generateResponse)
|
|
||||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
|
||||||
c.Usage = &Usage{}
|
|
||||||
}
|
|
||||||
d, err := json.Marshal(c)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if generateResponse.Done {
|
|
||||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
|
||||||
u := toUsageGenerate(generateResponse)
|
|
||||||
c.Usage = &u
|
|
||||||
c.Choices = []CompleteChunkChoice{}
|
|
||||||
d, err := json.Marshal(c)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// completion
|
|
||||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
||||||
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
|
||||||
code := w.ResponseWriter.Status()
|
|
||||||
if code != http.StatusOK {
|
|
||||||
return w.writeError(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.writeResponse(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
|
||||||
var listResponse api.ListResponse
|
|
||||||
err := json.Unmarshal(data, &listResponse)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
||||||
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *ListWriter) Write(data []byte) (int, error) {
|
|
||||||
code := w.ResponseWriter.Status()
|
|
||||||
if code != http.StatusOK {
|
|
||||||
return w.writeError(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.writeResponse(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
|
||||||
var showResponse api.ShowResponse
|
|
||||||
err := json.Unmarshal(data, &showResponse)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// retrieve completion
|
|
||||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
||||||
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
|
||||||
code := w.ResponseWriter.Status()
|
|
||||||
if code != http.StatusOK {
|
|
||||||
return w.writeError(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.writeResponse(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
|
||||||
var embedResponse api.EmbedResponse
|
|
||||||
err := json.Unmarshal(data, &embedResponse)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
||||||
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
|
||||||
code := w.ResponseWriter.Status()
|
|
||||||
if code != http.StatusOK {
|
|
||||||
return w.writeError(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.writeResponse(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ListMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
w := &ListWriter{
|
|
||||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Writer = w
|
|
||||||
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func RetrieveMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(&b)
|
|
||||||
|
|
||||||
// response writer
|
|
||||||
w := &RetrieveWriter{
|
|
||||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
||||||
model: c.Param("model"),
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Writer = w
|
|
||||||
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CompletionsMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
var req CompletionRequest
|
|
||||||
err := c.ShouldBindJSON(&req)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
genReq, err := fromCompleteRequest(req)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(&b)
|
|
||||||
|
|
||||||
w := &CompleteWriter{
|
|
||||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
||||||
stream: req.Stream,
|
|
||||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
|
||||||
streamOptions: req.StreamOptions,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Writer = w
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func EmbeddingsMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
var req EmbedRequest
|
|
||||||
err := c.ShouldBindJSON(&req)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Input == "" {
|
|
||||||
req.Input = []string{""}
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Input == nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(&b)
|
|
||||||
|
|
||||||
w := &EmbedWriter{
|
|
||||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
||||||
model: req.Model,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Writer = w
|
|
||||||
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ChatMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
var req ChatCompletionRequest
|
|
||||||
err := c.ShouldBindJSON(&req)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(req.Messages) == 0 {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
|
|
||||||
chatReq, err := fromChatRequest(req)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(&b)
|
|
||||||
|
|
||||||
w := &ChatWriter{
|
|
||||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
||||||
stream: req.Stream,
|
|
||||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
|
||||||
streamOptions: req.StreamOptions,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Writer = w
|
|
||||||
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -37,8 +37,8 @@ import (
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
|
"github.com/ollama/ollama/middleware"
|
||||||
"github.com/ollama/ollama/model/parsers"
|
"github.com/ollama/ollama/model/parsers"
|
||||||
"github.com/ollama/ollama/openai"
|
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/server/internal/registry"
|
"github.com/ollama/ollama/server/internal/registry"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
|
@ -1449,11 +1449,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||||
|
|
||||||
// Inference (OpenAI compatibility)
|
// Inference (OpenAI compatibility)
|
||||||
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
||||||
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
|
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||||
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
|
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||||
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
|
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||||
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
|
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||||
|
|
||||||
if rc != nil {
|
if rc != nil {
|
||||||
// wrap old with new
|
// wrap old with new
|
||||||
|
|
Loading…
Reference in New Issue