mirror of https://github.com/ollama/ollama.git

Compare commits
6 Commits: d09cefdc8e ... 225a2ab93a

| Author | SHA1 | Date |
|---|---|---|
|  | 225a2ab93a |  |
|  | bc71278670 |  |
|  | 918231931c |  |
|  | 04c1849878 |  |
|  | 2c2f4deaa9 |  |
|  | 04d9e178de |  |
@@ -442,15 +442,18 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
 		cmd.Stderr = os.Stderr
 	}
 	// cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored
-	cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
 	pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
 	pathNeeded := true
+	ollamaPathNeeded := true
 	extraDone := make([]bool, len(extraEnvs))
 	for i := range cmd.Env {
 		cmp := strings.SplitN(cmd.Env[i], "=", 2)
 		if strings.EqualFold(cmp[0], pathEnv) {
 			cmd.Env[i] = pathEnv + "=" + pathEnvVal
 			pathNeeded = false
+		} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
+			cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ollamaLibDirs, string(filepath.ListSeparator))
+			ollamaPathNeeded = false
 		} else {
 			for j := range extraEnvs {
 				if extraDone[j] {
@@ -467,6 +470,9 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
 	if pathNeeded {
 		cmd.Env = append(cmd.Env, pathEnv+"="+pathEnvVal)
 	}
+	if ollamaPathNeeded {
+		cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
+	}
 	for i := range extraDone {
 		if !extraDone[i] {
 			cmd.Env = append(cmd.Env, extraEnvs[i])
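The hunk above replaces an unconditional append of OLLAMA_LIBRARY_PATH with update-or-append logic, so the variable is never duplicated when the caller's environment already defines it. A minimal standalone sketch of that pattern (the setEnv helper name is hypothetical, not part of this change):

	package main

	import (
		"fmt"
		"strings"
	)

	// setEnv updates key in env in place (case-insensitively, matching how
	// Windows treats variable names) or appends it if no entry exists.
	func setEnv(env []string, key, val string) []string {
		for i := range env {
			kv := strings.SplitN(env[i], "=", 2)
			if strings.EqualFold(kv[0], key) {
				env[i] = key + "=" + val
				return env
			}
		}
		return append(env, key+"="+val)
	}

	func main() {
		env := []string{"PATH=/usr/bin", "HOME=/home/user"}
		env = setEnv(env, "OLLAMA_LIBRARY_PATH", "/opt/ollama/lib") // appended
		env = setEnv(env, "path", "/usr/local/bin:/usr/bin")        // updates PATH
		fmt.Println(env)
	}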
@@ -359,20 +359,22 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 	s.cmd.Stderr = s.status
 	s.cmd.SysProcAttr = LlamaServerSysProcAttr
 
-	s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
-
 	// Always filter down the set of GPUs in case there are any unsupported devices that might crash
 	envWorkarounds := gpus.GetVisibleDevicesEnv()
 	pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
 
 	// Update or add the path variable with our adjusted version
 	pathNeeded := true
+	ollamaPathNeeded := true
 	envWorkaroundDone := make([]bool, len(envWorkarounds))
 	for i := range s.cmd.Env {
 		cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
 		if strings.EqualFold(cmp[0], pathEnv) {
 			s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
 			pathNeeded = false
+		} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
+			s.cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ggmlPaths, string(filepath.ListSeparator))
+			ollamaPathNeeded = false
 		} else if len(envWorkarounds) != 0 {
 			for j, kv := range envWorkarounds {
 				tmp := strings.SplitN(kv, "=", 2)
@@ -386,6 +388,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 	if pathNeeded {
 		s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
 	}
+	if ollamaPathNeeded {
+		s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
+	}
 	for i, done := range envWorkaroundDone {
 		if !done {
 			s.cmd.Env = append(s.cmd.Env, envWorkarounds[i])
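Both env-scanning loops call strings.SplitN with a limit of 2. That limit matters: it splits only on the first "=", so values that themselves contain "=" survive intact. A tiny runnable demonstration:

	package main

	import (
		"fmt"
		"strings"
	)

	func main() {
		// Limit 2 keeps any '=' inside the value, which is why the
		// KEY=VALUE scanning loops above use SplitN rather than Split.
		kv := strings.SplitN("OLLAMA_KV=a=b=c", "=", 2)
		fmt.Println(kv[0]) // OLLAMA_KV
		fmt.Println(kv[1]) // a=b=c
	}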
@@ -0,0 +1,424 @@

package middleware

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"math/rand"
	"net/http"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/openai"
)

type BaseWriter struct {
	gin.ResponseWriter
}

type ChatWriter struct {
	stream        bool
	streamOptions *openai.StreamOptions
	id            string
	toolCallSent  bool
	BaseWriter
}

type CompleteWriter struct {
	stream        bool
	streamOptions *openai.StreamOptions
	id            string
	BaseWriter
}

type ListWriter struct {
	BaseWriter
}

type RetrieveWriter struct {
	BaseWriter
	model string
}

type EmbedWriter struct {
	BaseWriter
	model string
}

func (w *BaseWriter) writeError(data []byte) (int, error) {
	var serr api.StatusError
	err := json.Unmarshal(data, &serr)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(openai.NewError(http.StatusInternalServerError, serr.Error()))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
		return 0, err
	}

	// chat chunk
	if w.stream {
		c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}
		if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
			w.toolCallSent = true
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if chatResponse.Done {
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := openai.ToUsage(chatResponse)
				c.Usage = &u
				c.Choices = []openai.ChunkChoice{}
				d, err := json.Marshal(c)
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// chat completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToChatCompletion(w.id, chatResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ChatWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	err := json.Unmarshal(data, &generateResponse)
	if err != nil {
		return 0, err
	}

	// completion chunk
	if w.stream {
		c := openai.ToCompleteChunk(w.id, generateResponse)
		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
			c.Usage = &openai.Usage{}
		}
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}

		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
		if err != nil {
			return 0, err
		}

		if generateResponse.Done {
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := openai.ToUsageGenerate(generateResponse)
				c.Usage = &u
				c.Choices = []openai.CompleteChunkChoice{}
				d, err := json.Marshal(c)
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
			}
		}

		return len(data), nil
	}

	// completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToCompletion(w.id, generateResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *CompleteWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

func (w *ListWriter) writeResponse(data []byte) (int, error) {
	var listResponse api.ListResponse
	err := json.Unmarshal(data, &listResponse)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToListCompletion(listResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *ListWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
	var showResponse api.ShowResponse
	err := json.Unmarshal(data, &showResponse)
	if err != nil {
		return 0, err
	}

	// retrieve completion
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToModel(showResponse, w.model))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *RetrieveWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
	var embedResponse api.EmbedResponse
	err := json.Unmarshal(data, &embedResponse)
	if err != nil {
		return 0, err
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse))
	if err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *EmbedWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

func ListMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		w := &ListWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
		}

		c.Writer = w

		c.Next()
	}
}

func RetrieveMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &RetrieveWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      c.Param("model"),
		}

		c.Writer = w

		c.Next()
	}
}

func CompletionsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req openai.CompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
			return
		}

		var b bytes.Buffer
		genReq, err := openai.FromCompleteRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &CompleteWriter{
			BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
			stream:        req.Stream,
			id:            fmt.Sprintf("cmpl-%d", rand.Intn(999)),
			streamOptions: req.StreamOptions,
		}

		c.Writer = w
		c.Next()
	}
}

func EmbeddingsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req openai.EmbedRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if req.Input == "" {
			req.Input = []string{""}
		}

		if req.Input == nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		if v, ok := req.Input.([]any); ok && len(v) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
			return
		}

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &EmbedWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			model:      req.Model,
		}

		c.Writer = w

		c.Next()
	}
}

func ChatMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req openai.ChatCompletionRequest
		err := c.ShouldBindJSON(&req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if len(req.Messages) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
			return
		}

		var b bytes.Buffer

		chatReq, err := openai.FromChatRequest(req)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &ChatWriter{
			BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
			stream:        req.Stream,
			id:            fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
			streamOptions: req.StreamOptions,
		}

		c.Writer = w

		c.Next()
	}
}
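Each middleware rewrites the OpenAI-style request body into the native Ollama request, swaps in a response writer that converts the native response back, and then falls through to the native handler. A minimal wiring sketch under assumptions: the import path github.com/ollama/ollama/middleware is assumed from the package name (the diff does not show the directory), the /v1/* routes mirror Ollama's OpenAI-compatible endpoints, and the native handler here is a stub:

	package main

	import (
		"net/http"

		"github.com/gin-gonic/gin"

		"github.com/ollama/ollama/middleware" // assumed import path for the new package
	)

	func main() {
		r := gin.New()

		// Stub standing in for the native Ollama handlers.
		native := func(c *gin.Context) { c.Status(http.StatusOK) }

		r.POST("/v1/chat/completions", middleware.ChatMiddleware(), native)
		r.POST("/v1/completions", middleware.CompletionsMiddleware(), native)
		r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), native)
		r.GET("/v1/models", middleware.ListMiddleware(), native)
		r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), native)

		r.Run(":11434")
	}

The tests below use the same composition, mounting each middleware ahead of a stub endpoint on the native route.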
@@ -0,0 +1,928 @@

package middleware

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/openai"
)

const (
	prefix = `data:image/jpeg;base64,`
	image  = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)

var (
	False = false
	True  = true
)

func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
	return func(c *gin.Context) {
		bodyBytes, _ := io.ReadAll(c.Request.Body)
		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		err := json.Unmarshal(bodyBytes, capturedRequest)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
		}
		c.Next()
	}
}

func TestChatMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.ChatRequest
		err  openai.ErrorResponse
	}

	var capturedRequest *api.ChatRequest

	testCases := []testCase{
		{
			name: "chat handler",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with options",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				],
				"stream": true,
				"max_tokens": 999,
				"seed": 123,
				"stop": ["\n", "stop"],
				"temperature": 3.0,
				"frequency_penalty": 4.0,
				"presence_penalty": 5.0,
				"top_p": 6.0,
				"response_format": {"type": "json_object"}
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
					"seed":              123.0,
					"stop":              []any{"\n", "stop"},
					"temperature":       3.0,
					"frequency_penalty": 4.0,
					"presence_penalty":  5.0,
					"top_p":             6.0,
				},
				Format: json.RawMessage(`"json"`),
				Stream: &True,
			},
		},
		{
			name: "chat handler with streaming usage",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				],
				"stream": true,
				"stream_options": {"include_usage": true},
				"max_tokens": 999,
				"seed": 123,
				"stop": ["\n", "stop"],
				"temperature": 3.0,
				"frequency_penalty": 4.0,
				"presence_penalty": 5.0,
				"top_p": 6.0,
				"response_format": {"type": "json_object"}
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
					"seed":              123.0,
					"stop":              []any{"\n", "stop"},
					"temperature":       3.0,
					"frequency_penalty": 4.0,
					"presence_penalty":  5.0,
					"top_p":             6.0,
				},
				Format: json.RawMessage(`"json"`),
				Stream: &True,
			},
		},
		{
			name: "chat handler with image content",
			body: `{
				"model": "test-model",
				"messages": [
					{
						"role": "user",
						"content": [
							{
								"type": "text",
								"text": "Hello"
							},
							{
								"type": "image_url",
								"image_url": {
									"url": "` + prefix + image + `"
								}
							}
						]
					}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
					{
						Role: "user",
						Images: []api.ImageData{
							func() []byte {
								img, _ := base64.StdEncoding.DecodeString(image)
								return img
							}(),
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with tools",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									},
								},
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with tools and content",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role:    "assistant",
						Content: "Let's see what the weather is like in Paris",
						ToolCalls: []api.ToolCall{
							{
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									},
								},
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with tools and empty content",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									},
								},
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with tools and thinking content",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role:     "assistant",
						Thinking: "Let's see what the weather is like in Paris",
						ToolCalls: []api.ToolCall{
							{
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									},
								},
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "tool response with call ID",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
					{"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									},
								},
							},
						},
					},
					{
						Role:     "tool",
						Content:  "The weather in Paris is 20 degrees Celsius",
						ToolName: "get_current_weather",
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "tool response with name",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
					{"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: map[string]any{
										"location": "Paris, France",
										"format":   "celsius",
									},
								},
							},
						},
					},
					{
						Role:     "tool",
						Content:  "The weather in Paris is 20 degrees Celsius",
						ToolName: "get_current_weather",
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with streaming tools",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris?"}
				],
				"stream": true,
				"tools": [{
					"type": "function",
					"function": {
						"name": "get_weather",
						"description": "Get the current weather",
						"parameters": {
							"type": "object",
							"required": ["location"],
							"properties": {
								"location": {
									"type": "string",
									"description": "The city and state"
								},
								"unit": {
									"type": "string",
									"enum": ["celsius", "fahrenheit"]
								}
							}
						}
					}
				}]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris?",
					},
				},
				Tools: []api.Tool{
					{
						Type: "function",
						Function: api.ToolFunction{
							Name:        "get_weather",
							Description: "Get the current weather",
							Parameters: struct {
								Type       string                      `json:"type"`
								Defs       any                         `json:"$defs,omitempty"`
								Items      any                         `json:"items,omitempty"`
								Required   []string                    `json:"required"`
								Properties map[string]api.ToolProperty `json:"properties"`
							}{
								Type:     "object",
								Required: []string{"location"},
								Properties: map[string]api.ToolProperty{
									"location": {
										Type:        api.PropertyType{"string"},
										Description: "The city and state",
									},
									"unit": {
										Type: api.PropertyType{"string"},
										Enum: []any{"celsius", "fahrenheit"},
									},
								},
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &True,
			},
		},
		{
			name: "chat handler error forwarding",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": 2}
				]
			}`,
			err: openai.ErrorResponse{
				Error: openai.Error{
					Message: "invalid message content type: float64",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/chat", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")

			defer func() { capturedRequest = nil }()

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var errResp openai.ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
				return
			}
			if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
				t.Fatalf("requests did not match: %+v", diff)
			}
			if diff := cmp.Diff(tc.err, errResp); diff != "" {
				t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
			}
		})
	}
}

func TestCompletionsMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.GenerateRequest
		err  openai.ErrorResponse
	}

	var capturedRequest *api.GenerateRequest

	testCases := []testCase{
		{
			name: "completions handler",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
					"temperature":       0.8,
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &False,
			},
		},
		{
			name: "completions handler stream",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"stream": true,
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
					"temperature":       0.8,
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &True,
			},
		},
		{
			name: "completions handler stream with usage",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"stream": true,
				"stream_options": {"include_usage": true},
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
					"temperature":       0.8,
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &True,
			},
		},
		{
			name: "completions handler error forwarding",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"temperature": null,
				"stop": [1, 2],
				"suffix": "suffix"
			}`,
			err: openai.ErrorResponse{
				Error: openai.Error{
					Message: "invalid type for 'stop' field: float64",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/generate", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var errResp openai.ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}

			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}

			capturedRequest = nil
		})
	}
}

func TestEmbeddingsMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.EmbedRequest
		err  openai.ErrorResponse
	}

	var capturedRequest *api.EmbedRequest

	testCases := []testCase{
		{
			name: "embed handler single input",
			body: `{
				"input": "Hello",
				"model": "test-model"
			}`,
			req: api.EmbedRequest{
				Input: "Hello",
				Model: "test-model",
			},
		},
		{
			name: "embed handler batch input",
			body: `{
				"input": ["Hello", "World"],
				"model": "test-model"
			}`,
			req: api.EmbedRequest{
				Input: []any{"Hello", "World"},
				Model: "test-model",
			},
		},
		{
			name: "embed handler error forwarding",
			body: `{
				"model": "test-model"
			}`,
			err: openai.ErrorResponse{
				Error: openai.Error{
					Message: "invalid input",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/embed", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var errResp openai.ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}

			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}

			capturedRequest = nil
		})
	}
}

func TestListMiddleware(t *testing.T) {
	type testCase struct {
		name     string
		endpoint func(c *gin.Context)
		resp     string
	}

	testCases := []testCase{
		{
			name: "list handler",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{
					Models: []api.ListModelResponse{
						{
							Name:       "test-model",
							ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
						},
					},
				})
			},
			resp: `{
				"object": "list",
				"data": [
					{
						"id": "test-model",
						"object": "model",
						"created": 1686935002,
						"owned_by": "library"
					}
				]
			}`,
		},
		{
			name: "list handler empty output",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{})
			},
			resp: `{
				"object": "list",
				"data": null
			}`,
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tc := range testCases {
		router := gin.New()
		router.Use(ListMiddleware())
		router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
		req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)

		resp := httptest.NewRecorder()
		router.ServeHTTP(resp, req)

		var expected, actual map[string]any
		err := json.Unmarshal([]byte(tc.resp), &expected)
		if err != nil {
			t.Fatalf("failed to unmarshal expected response: %v", err)
		}

		err = json.Unmarshal(resp.Body.Bytes(), &actual)
		if err != nil {
			t.Fatalf("failed to unmarshal actual response: %v", err)
		}

		if !reflect.DeepEqual(expected, actual) {
			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
		}
	}
}

func TestRetrieveMiddleware(t *testing.T) {
	type testCase struct {
		name     string
		endpoint func(c *gin.Context)
		resp     string
	}

	testCases := []testCase{
		{
			name: "retrieve handler",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ShowResponse{
					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
				})
			},
			resp: `{
				"id":"test-model",
				"object":"model",
				"created":1686935002,
				"owned_by":"library"}
			`,
		},
		{
			name: "retrieve handler error forwarding",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
			},
			resp: `{
				"error": {
					"code": null,
					"message": "model not found",
					"param": null,
					"type": "api_error"
				}
			}`,
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tc := range testCases {
		router := gin.New()
		router.Use(RetrieveMiddleware())
		router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
		req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)

		resp := httptest.NewRecorder()
		router.ServeHTTP(resp, req)

		var expected, actual map[string]any
		err := json.Unmarshal([]byte(tc.resp), &expected)
		if err != nil {
			t.Fatalf("failed to unmarshal expected response: %v", err)
		}

		err = json.Unmarshal(resp.Body.Bytes(), &actual)
		if err != nil {
			t.Fatalf("failed to unmarshal actual response: %v", err)
		}

		if !reflect.DeepEqual(expected, actual) {
			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
		}
	}
}
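The suite above only inspects the rewritten request bodies and error envelopes. One additional check this structure would support, sketched here as an illustrative test (not part of the change, same imports as the file above): drive a done ChatResponse through ChatWriter in streaming mode and assert the SSE framing, since the writer emits "data: ...\n\n" frames terminated by "data: [DONE]\n\n".

	func TestChatWriterStreamFraming(t *testing.T) {
		gin.SetMode(gin.TestMode)
		rec := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(rec)

		// Status defaults to 200, so Write takes the writeResponse path.
		w := &ChatWriter{
			stream:     true,
			id:         "chatcmpl-1",
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
		}

		resp := api.ChatResponse{
			Done:    true,
			Message: api.Message{Role: "assistant", Content: "hi"},
		}
		data, _ := json.Marshal(resp)
		if _, err := w.Write(data); err != nil {
			t.Fatal(err)
		}

		body := rec.Body.String()
		if !strings.HasPrefix(body, "data: ") || !strings.HasSuffix(body, "data: [DONE]\n\n") {
			t.Fatalf("unexpected SSE framing: %q", body)
		}
	}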
openai/openai.go (465 lines changed)
@ -1,21 +1,17 @@
|
||||||
// openai package provides middleware for partial compatibility with the OpenAI REST API
|
// openai package provides core transformation logic for partial compatibility with the OpenAI REST API
|
||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
@ -220,11 +216,12 @@ func NewError(code int, message string) ErrorResponse {
|
||||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toUsage(r api.ChatResponse) Usage {
|
// ToUsage converts an api.ChatResponse to Usage
|
||||||
|
func ToUsage(r api.ChatResponse) Usage {
|
||||||
return Usage{
|
return Usage{
|
||||||
PromptTokens: r.PromptEvalCount,
|
PromptTokens: r.Metrics.PromptEvalCount,
|
||||||
CompletionTokens: r.EvalCount,
|
CompletionTokens: r.Metrics.EvalCount,
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,7 +253,8 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
|
||||||
return toolCalls
|
return toolCalls
|
||||||
}
|
}
|
||||||
|
|
||||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
// ToChatCompletion converts an api.ChatResponse to ChatCompletion
|
||||||
|
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||||
return ChatCompletion{
|
return ChatCompletion{
|
||||||
Id: id,
|
Id: id,
|
||||||
|
@ -276,12 +274,13 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}}, Usage: toUsage(r),
|
}}, Usage: ToUsage(r),
|
||||||
DebugInfo: r.DebugInfo,
|
DebugInfo: r.DebugInfo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
// ToChunk converts an api.ChatResponse to ChatCompletionChunk
|
||||||
|
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
||||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||||
return ChatCompletionChunk{
|
return ChatCompletionChunk{
|
||||||
Id: id,
|
Id: id,
|
||||||
|
@ -305,15 +304,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toUsageGenerate(r api.GenerateResponse) Usage {
|
// ToUsageGenerate converts an api.GenerateResponse to Usage
|
||||||
|
func ToUsageGenerate(r api.GenerateResponse) Usage {
|
||||||
return Usage{
|
return Usage{
|
||||||
PromptTokens: r.PromptEvalCount,
|
PromptTokens: r.Metrics.PromptEvalCount,
|
||||||
CompletionTokens: r.EvalCount,
|
CompletionTokens: r.Metrics.EvalCount,
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toCompletion(id string, r api.GenerateResponse) Completion {
|
// ToCompletion converts an api.GenerateResponse to Completion
|
||||||
|
func ToCompletion(id string, r api.GenerateResponse) Completion {
|
||||||
return Completion{
|
return Completion{
|
||||||
Id: id,
|
Id: id,
|
||||||
Object: "text_completion",
|
Object: "text_completion",
|
||||||
|
@ -330,11 +331,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||||
return nil
|
return nil
|
||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: toUsageGenerate(r),
|
Usage: ToUsageGenerate(r),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk
|
||||||
|
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||||
return CompletionChunk{
|
return CompletionChunk{
|
||||||
Id: id,
|
Id: id,
|
||||||
Object: "text_completion",
|
Object: "text_completion",
|
||||||
|
@ -354,7 +356,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toListCompletion(r api.ListResponse) ListCompletion {
|
// ToListCompletion converts an api.ListResponse to ListCompletion
|
||||||
|
func ToListCompletion(r api.ListResponse) ListCompletion {
|
||||||
var data []Model
|
var data []Model
|
||||||
for _, m := range r.Models {
|
for _, m := range r.Models {
|
||||||
data = append(data, Model{
|
data = append(data, Model{
|
||||||
|
@ -371,7 +374,8 @@ func toListCompletion(r api.ListResponse) ListCompletion {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
|
||||||
|
func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||||
if r.Embeddings != nil {
|
if r.Embeddings != nil {
|
||||||
var data []Embedding
|
var data []Embedding
|
||||||
for i, e := range r.Embeddings {
|
for i, e := range r.Embeddings {
|
||||||
|
@ -396,7 +400,8 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||||
return EmbeddingList{}
|
return EmbeddingList{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toModel(r api.ShowResponse, m string) Model {
|
// ToModel converts an api.ShowResponse to Model
|
||||||
|
func ToModel(r api.ShowResponse, m string) Model {
|
||||||
return Model{
|
return Model{
|
||||||
Id: m,
|
Id: m,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
|
@ -405,7 +410,8 @@ func toModel(r api.ShowResponse, m string) Model {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
|
||||||
|
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
for _, msg := range r.Messages {
|
for _, msg := range r.Messages {
|
||||||
toolName := ""
|
toolName := ""
|
||||||
|
@ -609,7 +615,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
|
||||||
return apiToolCalls, nil
|
return apiToolCalls, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
|
||||||
|
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||||
options := make(map[string]any)
|
options := make(map[string]any)
|
||||||
|
|
||||||
switch stop := r.Stop.(type) {
|
switch stop := r.Stop.(type) {
|
||||||
|
@ -660,413 +667,3 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||||
DebugRenderOnly: r.DebugRenderOnly,
|
DebugRenderOnly: r.DebugRenderOnly,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
-
-type BaseWriter struct {
-	gin.ResponseWriter
-}
-
-type ChatWriter struct {
-	stream        bool
-	streamOptions *StreamOptions
-	id            string
-	toolCallSent  bool
-	BaseWriter
-}
-
-type CompleteWriter struct {
-	stream        bool
-	streamOptions *StreamOptions
-	id            string
-	BaseWriter
-}
-
-type ListWriter struct {
-	BaseWriter
-}
-
-type RetrieveWriter struct {
-	BaseWriter
-	model string
-}
-
-type EmbedWriter struct {
-	BaseWriter
-	model string
-}
-
-func (w *BaseWriter) writeError(data []byte) (int, error) {
-	var serr api.StatusError
-	err := json.Unmarshal(data, &serr)
-	if err != nil {
-		return 0, err
-	}
-
-	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
-	if err != nil {
-		return 0, err
-	}
-
-	return len(data), nil
-}
-
-func (w *ChatWriter) writeResponse(data []byte) (int, error) {
-	var chatResponse api.ChatResponse
-	err := json.Unmarshal(data, &chatResponse)
-	if err != nil {
-		return 0, err
-	}
-
-	// chat chunk
-	if w.stream {
-		c := toChunk(w.id, chatResponse, w.toolCallSent)
-		d, err := json.Marshal(c)
-		if err != nil {
-			return 0, err
-		}
-		if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
-			w.toolCallSent = true
-		}
-
-		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
-		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
-		if err != nil {
-			return 0, err
-		}
-
-		if chatResponse.Done {
-			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
-				u := toUsage(chatResponse)
-				c.Usage = &u
-				c.Choices = []ChunkChoice{}
-				d, err := json.Marshal(c)
-				if err != nil {
-					return 0, err
-				}
-				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
-				if err != nil {
-					return 0, err
-				}
-			}
-			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
-			if err != nil {
-				return 0, err
-			}
-		}
-
-		return len(data), nil
-	}
-
-	// chat completion
-	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
-	if err != nil {
-		return 0, err
-	}
-
-	return len(data), nil
-}
-
-func (w *ChatWriter) Write(data []byte) (int, error) {
-	code := w.ResponseWriter.Status()
-	if code != http.StatusOK {
-		return w.writeError(data)
-	}
-
-	return w.writeResponse(data)
-}
-
-func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
-	var generateResponse api.GenerateResponse
-	err := json.Unmarshal(data, &generateResponse)
-	if err != nil {
-		return 0, err
-	}
-
-	// completion chunk
-	if w.stream {
-		c := toCompleteChunk(w.id, generateResponse)
-		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
-			c.Usage = &Usage{}
-		}
-		d, err := json.Marshal(c)
-		if err != nil {
-			return 0, err
-		}
-
-		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
-		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
-		if err != nil {
-			return 0, err
-		}
-
-		if generateResponse.Done {
-			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
-				u := toUsageGenerate(generateResponse)
-				c.Usage = &u
-				c.Choices = []CompleteChunkChoice{}
-				d, err := json.Marshal(c)
-				if err != nil {
-					return 0, err
-				}
-				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
-				if err != nil {
-					return 0, err
-				}
-			}
-			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
-			if err != nil {
-				return 0, err
-			}
-		}
-
-		return len(data), nil
-	}
-
-	// completion
-	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
-	if err != nil {
-		return 0, err
-	}
-
-	return len(data), nil
-}
-
-func (w *CompleteWriter) Write(data []byte) (int, error) {
-	code := w.ResponseWriter.Status()
-	if code != http.StatusOK {
-		return w.writeError(data)
-	}
-
-	return w.writeResponse(data)
-}
-
-func (w *ListWriter) writeResponse(data []byte) (int, error) {
-	var listResponse api.ListResponse
-	err := json.Unmarshal(data, &listResponse)
-	if err != nil {
-		return 0, err
-	}
-
-	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
-	if err != nil {
-		return 0, err
-	}
-
-	return len(data), nil
-}
-
-func (w *ListWriter) Write(data []byte) (int, error) {
-	code := w.ResponseWriter.Status()
-	if code != http.StatusOK {
-		return w.writeError(data)
-	}
-
-	return w.writeResponse(data)
-}
-
-func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
-	var showResponse api.ShowResponse
-	err := json.Unmarshal(data, &showResponse)
-	if err != nil {
-		return 0, err
-	}
-
-	// retrieve completion
-	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
-	if err != nil {
-		return 0, err
-	}
-
-	return len(data), nil
-}
-
-func (w *RetrieveWriter) Write(data []byte) (int, error) {
-	code := w.ResponseWriter.Status()
-	if code != http.StatusOK {
-		return w.writeError(data)
-	}
-
-	return w.writeResponse(data)
-}
-
-func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
-	var embedResponse api.EmbedResponse
-	err := json.Unmarshal(data, &embedResponse)
-	if err != nil {
-		return 0, err
-	}
-
-	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
-	if err != nil {
-		return 0, err
-	}
-
-	return len(data), nil
-}
-
-func (w *EmbedWriter) Write(data []byte) (int, error) {
-	code := w.ResponseWriter.Status()
-	if code != http.StatusOK {
-		return w.writeError(data)
-	}
-
-	return w.writeResponse(data)
-}
-
-func ListMiddleware() gin.HandlerFunc {
-	return func(c *gin.Context) {
-		w := &ListWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-		}
-
-		c.Writer = w
-
-		c.Next()
-	}
-}
-
-func RetrieveMiddleware() gin.HandlerFunc {
-	return func(c *gin.Context) {
-		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
-			return
-		}
-
-		c.Request.Body = io.NopCloser(&b)
-
-		// response writer
-		w := &RetrieveWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			model:      c.Param("model"),
-		}
-
-		c.Writer = w
-
-		c.Next()
-	}
-}
-
-func CompletionsMiddleware() gin.HandlerFunc {
-	return func(c *gin.Context) {
-		var req CompletionRequest
-		err := c.ShouldBindJSON(&req)
-		if err != nil {
-			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
-			return
-		}
-
-		var b bytes.Buffer
-		genReq, err := fromCompleteRequest(req)
-		if err != nil {
-			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
-			return
-		}
-
-		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
-			return
-		}
-
-		c.Request.Body = io.NopCloser(&b)
-
-		w := &CompleteWriter{
-			BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
-			stream:        req.Stream,
-			id:            fmt.Sprintf("cmpl-%d", rand.Intn(999)),
-			streamOptions: req.StreamOptions,
-		}
-
-		c.Writer = w
-		c.Next()
-	}
-}
-
-func EmbeddingsMiddleware() gin.HandlerFunc {
-	return func(c *gin.Context) {
-		var req EmbedRequest
-		err := c.ShouldBindJSON(&req)
-		if err != nil {
-			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
-			return
-		}
-
-		if req.Input == "" {
-			req.Input = []string{""}
-		}
-
-		if req.Input == nil {
-			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
-			return
-		}
-
-		if v, ok := req.Input.([]any); ok && len(v) == 0 {
-			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
-			return
-		}
-
-		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
-			return
-		}
-
-		c.Request.Body = io.NopCloser(&b)
-
-		w := &EmbedWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			model:      req.Model,
-		}
-
-		c.Writer = w
-
-		c.Next()
-	}
-}
-
-func ChatMiddleware() gin.HandlerFunc {
-	return func(c *gin.Context) {
-		var req ChatCompletionRequest
-		err := c.ShouldBindJSON(&req)
-		if err != nil {
-			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
-			return
-		}
-
-		if len(req.Messages) == 0 {
-			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
-			return
-		}
-
-		var b bytes.Buffer
-
-		chatReq, err := fromChatRequest(req)
-		if err != nil {
-			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
-			return
-		}
-
-		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
-			return
-		}
-
-		c.Request.Body = io.NopCloser(&b)
-
-		w := &ChatWriter{
-			BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
-			stream:        req.Stream,
-			id:            fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
-			streamOptions: req.StreamOptions,
-		}
-
-		c.Writer = w
-
-		c.Next()
-	}
-}
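
One detail worth calling out in the deleted streaming path: ChatWriter and CompleteWriter frame every chunk as a server-sent event, `data: <json>` followed by a blank line, emit an optional usage-only chunk when stream_options.include_usage is set, and terminate the stream with a literal `data: [DONE]`. A minimal client-side sketch of consuming that framing (the port is Ollama's default; the model name is a placeholder):

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := strings.NewReader(`{"model":"llama3.2","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
	resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip the blank separator lines between events
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" { // stream terminator written by ChatWriter
			break
		}
		fmt.Println("chunk:", payload) // one chat.completion.chunk JSON object per event
	}
}
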
File diff suppressed because it is too large
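
The suppressed diff is presumably the new middleware package that receives the code deleted above. The mechanism it carries over is worth a note: each middleware constructor swaps gin's c.Writer for a wrapper whose Write re-encodes the native response body on its way out. A stripped-down sketch of that pattern, with no Ollama types and invented names, purely illustrative:

package main

import (
	"bytes"
	"net/http"

	"github.com/gin-gonic/gin"
)

// upperWriter is a toy wrapper showing the same trick the deleted
// middleware used: embed gin.ResponseWriter, override Write, and
// transform the body before it reaches the client.
type upperWriter struct {
	gin.ResponseWriter
}

func (w *upperWriter) Write(data []byte) (int, error) {
	// A real translator would unmarshal `data` and re-encode it into the
	// target wire format; upper-casing keeps the sketch self-contained.
	return w.ResponseWriter.Write(bytes.ToUpper(data))
}

func UpperMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Writer = &upperWriter{ResponseWriter: c.Writer} // swap the writer in
		c.Next()                                          // downstream handlers now write through the wrapper
	}
}

func main() {
	r := gin.New()
	r.Use(UpperMiddleware())
	r.GET("/ping", func(c *gin.Context) { c.String(http.StatusOK, "pong") })
	r.Run(":8080") // curl localhost:8080/ping -> "PONG"
}
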
@@ -179,7 +179,7 @@ function buildROCm() {
         if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
         & cmake --install build --component "HIP" --strip
         if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
-        rm -f $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906*
+        Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
         }
     }
 }

@@ -79,9 +79,13 @@ status "Installing ollama to $OLLAMA_INSTALL_DIR"
 $SUDO install -o0 -g0 -m755 -d $BINDIR
 $SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
 status "Downloading Linux ${ARCH} bundle"
-curl --fail --show-error --location --progress-bar \
-    "https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
-    $SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
+if ! curl --fail --show-error --location --progress-bar --output "ollama-linux-${ARCH}.tgz${VER_PARAM}" \
+    "https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}"; then
+    error "Failed to download or extract Ollama"
+fi
+status "Unpacking Ollama"
+$SUDO tar -xzf "ollama-linux-${ARCH}.tgz${VER_PARAM}" -C "$OLLAMA_INSTALL_DIR"
+rm "ollama-linux-${ARCH}.tgz${VER_PARAM}"
 
 if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
     status "Making ollama accessible in the PATH in $BINDIR"

@@ -37,8 +37,8 @@ import (
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/logutil"
+	"github.com/ollama/ollama/middleware"
 	"github.com/ollama/ollama/model/parsers"
-	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/registry"
 	"github.com/ollama/ollama/template"

@@ -1449,11 +1449,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.POST("/api/embeddings", s.EmbeddingsHandler)
 
 	// Inference (OpenAI compatibility)
-	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
-	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
-	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
-	r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
-	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
+	r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
+	r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
+	r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
+	r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
+	r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
 
 	if rc != nil {
 		// wrap old with new
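
Net effect of this last hunk: the same five OpenAI-compatibility routes, now registered through the middleware package instead of openai. A quick way to sanity-check the two read-only routes after a build, sketched against Ollama's default port ("llama3.2" is a placeholder model name):

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// The two GET routes registered above: ListMiddleware + ListHandler,
	// and RetrieveMiddleware + ShowHandler.
	for _, url := range []string{
		"http://localhost:11434/v1/models",
		"http://localhost:11434/v1/models/llama3.2",
	} {
		resp, err := http.Get(url)
		if err != nil {
			panic(err)
		}
		b, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		if len(b) > 80 {
			b = b[:80] // print just a prefix of each JSON reply
		}
		fmt.Println(url, "->", string(b))
	}
}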