Compare commits

...

7 Commits

Author SHA1 Message Date
alextibbles 5b8de222fa
Merge 5123129def into bd15eba4e4 2025-10-07 19:18:42 -04:00
Daniel Hiltgen bd15eba4e4
Bring back escape valve for llm libraries and fix Jetpack6 crash (#12529)
* Bring back escape valve for llm libraries

If the new discovery logic picks the wrong library, this gives users the
ability to force a specific one using the same pattern as before. This
can also potentially speed up bootstrap discovery if one of the libraries
takes a long time to load and ultimately bind to no devices.  For example
unsupported AMD iGPUS can sometimes take a while to discover and rule out.

* Bypass extra discovery on jetpack systems

On at least Jetpack6, cuda_v12 appears to expose the iGPU, but crashes later on in
cublasInit so if we detect a Jetpack, short-circuit and use that variant.
2025-10-07 16:06:14 -07:00
Devon Rifkin bc71278670
Merge pull request #12509 from ollama/drifkin/oai-compat-refactor
openai: refactor to split compat layer and middleware
2025-10-06 16:22:08 -07:00
Daniel Hiltgen 918231931c
win: fix build script (#12513) 2025-10-06 14:46:45 -07:00
Daniel Hiltgen 04c1849878
discovery: prevent dup OLLAMA_LIBRARY_PATH (#12514)
This variable isn't currently documented or intended as something the user can
override, but if the user happens to set OLLAMA_LIBRARY_PATH we were doubling
this in the subprocess environment which will cause problems with the new
bootstrap discovery logic.
2025-10-06 14:36:44 -07:00
Devon Rifkin 2c2f4deaa9 openai: refactor to split compat layer and middleware
This makes the core openai compat layer independent of the middleware
that adapts it to our particular gin routes
2025-10-05 14:18:56 -07:00
Alex Tibbles 5123129def
build: update go stdlib to 1.24.6 2025-08-08 14:32:03 -04:00
11 changed files with 1543 additions and 1322 deletions

View File

@ -6,7 +6,9 @@ import (
"log/slog"
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"github.com/ollama/ollama/format"
@ -146,3 +148,35 @@ func GetSystemInfo() SystemInfo {
GPUs: gpus,
}
}
func cudaJetpack() string {
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
if CudaTegra != "" {
ver := strings.Split(CudaTegra, ".")
if len(ver) > 0 {
return "jetpack" + ver[0]
}
} else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil {
r := regexp.MustCompile(` R(\d+) `)
m := r.FindSubmatch(data)
if len(m) != 2 {
slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version")
} else {
if l4t, err := strconv.Atoi(string(m[1])); err == nil {
// Note: mapping from L4t -> JP is inconsistent (can't just subtract 30)
// https://developer.nvidia.com/embedded/jetpack-archive
switch l4t {
case 35:
return "jetpack5"
case 36:
return "jetpack6"
default:
// Newer Jetson systems use the SBSU runtime
slog.Debug("unrecognized L4T version", "nv_tegra_release", string(data))
}
}
}
}
}
return ""
}

View File

@ -78,6 +78,8 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
}
slog.Info("discovering available GPUs...")
requested := envconfig.LLMLibrary()
jetpack := cudaJetpack()
// For our initial discovery pass, we gather all the known GPUs through
// all the libraries that were detected. This pass may include GPUs that
@ -86,6 +88,14 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
// times concurrently leading to memory contention
for dir := range libDirs {
var dirs []string
if dir != "" {
if requested != "" && filepath.Base(dir) != requested {
slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
}
}
if dir == "" {
dirs = []string{LibOllamaPath}
} else {
@ -442,15 +452,18 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
cmd.Stderr = os.Stderr
}
// cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored
cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
pathNeeded := true
ollamaPathNeeded := true
extraDone := make([]bool, len(extraEnvs))
for i := range cmd.Env {
cmp := strings.SplitN(cmd.Env[i], "=", 2)
if strings.EqualFold(cmp[0], pathEnv) {
cmd.Env[i] = pathEnv + "=" + pathEnvVal
pathNeeded = false
} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ollamaLibDirs, string(filepath.ListSeparator))
ollamaPathNeeded = false
} else {
for j := range extraEnvs {
if extraDone[j] {
@ -467,6 +480,9 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
if pathNeeded {
cmd.Env = append(cmd.Env, pathEnv+"="+pathEnvVal)
}
if ollamaPathNeeded {
cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
}
for i := range extraDone {
if !extraDone[i] {
cmd.Env = append(cmd.Env, extraEnvs[i])

View File

@ -48,16 +48,10 @@ Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
**Experimental LLM Library Override**
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use:
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to limit autodetection, so for example, if you have both CUDA and AMD GPUs, but want to force the CUDA v13 only, use:
```shell
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
```
You can see what features your CPU has with the following.
```shell
cat /proc/cpuinfo| grep flags | head -1
OLLAMA_LLM_LIBRARY="cuda_v13" ollama serve
```
## Installing older or pre-release versions on Linux

2
go.mod
View File

@ -1,6 +1,6 @@
module github.com/ollama/ollama
go 1.24.0
go 1.24.6
require (
github.com/containerd/console v1.0.3

View File

@ -359,20 +359,22 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
s.cmd.Stderr = s.status
s.cmd.SysProcAttr = LlamaServerSysProcAttr
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
// Always filter down the set of GPUs in case there are any unsupported devices that might crash
envWorkarounds := gpus.GetVisibleDevicesEnv()
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add the path variable with our adjusted version
pathNeeded := true
ollamaPathNeeded := true
envWorkaroundDone := make([]bool, len(envWorkarounds))
for i := range s.cmd.Env {
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
if strings.EqualFold(cmp[0], pathEnv) {
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
pathNeeded = false
} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
s.cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ggmlPaths, string(filepath.ListSeparator))
ollamaPathNeeded = false
} else if len(envWorkarounds) != 0 {
for j, kv := range envWorkarounds {
tmp := strings.SplitN(kv, "=", 2)
@ -386,6 +388,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
if pathNeeded {
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
}
if ollamaPathNeeded {
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
}
for i, done := range envWorkaroundDone {
if !done {
s.cmd.Env = append(s.cmd.Env, envWorkarounds[i])

424
middleware/openai.go Normal file
View File

@ -0,0 +1,424 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
)
type BaseWriter struct {
gin.ResponseWriter
}
type ChatWriter struct {
stream bool
streamOptions *openai.StreamOptions
id string
toolCallSent bool
BaseWriter
}
type CompleteWriter struct {
stream bool
streamOptions *openai.StreamOptions
id string
BaseWriter
}
type ListWriter struct {
BaseWriter
}
type RetrieveWriter struct {
BaseWriter
model string
}
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.NewError(http.StatusInternalServerError, serr.Error()))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
// chat chunk
if w.stream {
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
w.toolCallSent = true
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if chatResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := openai.ToUsage(chatResponse)
c.Usage = &u
c.Choices = []openai.ChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// chat completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToChatCompletion(w.id, chatResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
err := json.Unmarshal(data, &generateResponse)
if err != nil {
return 0, err
}
// completion chunk
if w.stream {
c := openai.ToCompleteChunk(w.id, generateResponse)
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &openai.Usage{}
}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if generateResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := openai.ToUsageGenerate(generateResponse)
c.Usage = &u
c.Choices = []openai.CompleteChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToCompletion(w.id, generateResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *ListWriter) writeResponse(data []byte) (int, error) {
var listResponse api.ListResponse
err := json.Unmarshal(data, &listResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToListCompletion(listResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
var showResponse api.ShowResponse
err := json.Unmarshal(data, &showResponse)
if err != nil {
return 0, err
}
// retrieve completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToModel(showResponse, w.model))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
func RetrieveMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &RetrieveWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: c.Param("model"),
}
c.Writer = w
c.Next()
}
}
func CompletionsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.CompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
genReq, err := openai.FromCompleteRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &CompleteWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w
c.Next()
}
}
func ChatMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ChatCompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
return
}
var b bytes.Buffer
chatReq, err := openai.FromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}

928
middleware/openai_test.go Normal file
View File

@ -0,0 +1,928 @@
package middleware
import (
"bytes"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
)
const (
prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
var (
False = false
True = true
)
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
err := json.Unmarshal(bodyBytes, capturedRequest)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
}
c.Next()
}
}
func TestChatMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err openai.ErrorResponse
}
var capturedRequest *api.ChatRequest
testCases := []testCase{
{
name: "chat handler",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with options",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": true,
"max_tokens": 999,
"seed": 123,
"stop": ["\n", "stop"],
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
"response_format": {"type": "json_object"}
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
"seed": 123.0,
"stop": []any{"\n", "stop"},
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
},
Format: json.RawMessage(`"json"`),
Stream: &True,
},
},
{
name: "chat handler with streaming usage",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": true,
"stream_options": {"include_usage": true},
"max_tokens": 999,
"seed": 123,
"stop": ["\n", "stop"],
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
"response_format": {"type": "json_object"}
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
"seed": 123.0,
"stop": []any{"\n", "stop"},
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
},
Format: json.RawMessage(`"json"`),
Stream: &True,
},
},
{
name: "chat handler with image content",
body: `{
"model": "test-model",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Hello"
},
{
"type": "image_url",
"image_url": {
"url": "` + prefix + image + `"
}
}
]
}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
{
Role: "user",
Images: []api.ImageData{
func() []byte {
img, _ := base64.StdEncoding.DecodeString(image)
return img
}(),
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools and content",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
Content: "Let's see what the weather is like in Paris",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools and empty content",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools and thinking content",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
Thinking: "Let's see what the weather is like in Paris",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "tool response with call ID",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
{"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
{
Role: "tool",
Content: "The weather in Paris is 20 degrees Celsius",
ToolName: "get_current_weather",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "tool response with name",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
{"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
{
Role: "tool",
Content: "The weather in Paris is 20 degrees Celsius",
ToolName: "get_current_weather",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with streaming tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris?"}
],
"stream": true,
"tools": [{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {
"type": "string",
"description": "The city and state"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
}
}
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris?",
},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state",
},
"unit": {
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &True,
},
},
{
name: "chat handler error forwarding",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": 2}
]
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "invalid message content type: float64",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var errResp openai.ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
return
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match: %+v", diff)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
}
})
}
}
func TestCompletionsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
name: "completions handler",
body: `{
"model": "test-model",
"prompt": "Hello",
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 0.8,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &False,
},
},
{
name: "completions handler stream",
body: `{
"model": "test-model",
"prompt": "Hello",
"stream": true,
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 0.8,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &True,
},
},
{
name: "completions handler stream with usage",
body: `{
"model": "test-model",
"prompt": "Hello",
"stream": true,
"stream_options": {"include_usage": true},
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 0.8,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &True,
},
},
{
name: "completions handler error forwarding",
body: `{
"model": "test-model",
"prompt": "Hello",
"temperature": null,
"stop": [1, 2],
"suffix": "suffix"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "invalid type for 'stop' field: float64",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var errResp openai.ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil
})
}
}
func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.EmbedRequest
err openai.ErrorResponse
}
var capturedRequest *api.EmbedRequest
testCases := []testCase{
{
name: "embed handler single input",
body: `{
"input": "Hello",
"model": "test-model"
}`,
req: api.EmbedRequest{
Input: "Hello",
Model: "test-model",
},
},
{
name: "embed handler batch input",
body: `{
"input": ["Hello", "World"],
"model": "test-model"
}`,
req: api.EmbedRequest{
Input: []any{"Hello", "World"},
Model: "test-model",
},
},
{
name: "embed handler error forwarding",
body: `{
"model": "test-model"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "invalid input",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var errResp openai.ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil
})
}
}
func TestListMiddleware(t *testing.T) {
type testCase struct {
name string
endpoint func(c *gin.Context)
resp string
}
testCases := []testCase{
{
name: "list handler",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{
{
Name: "test-model",
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
},
},
})
},
resp: `{
"object": "list",
"data": [
{
"id": "test-model",
"object": "model",
"created": 1686935002,
"owned_by": "library"
}
]
}`,
},
{
name: "list handler empty output",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{})
},
resp: `{
"object": "list",
"data": null
}`,
},
}
gin.SetMode(gin.TestMode)
for _, tc := range testCases {
router := gin.New()
router.Use(ListMiddleware())
router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
}
}
func TestRetrieveMiddleware(t *testing.T) {
type testCase struct {
name string
endpoint func(c *gin.Context)
resp string
}
testCases := []testCase{
{
name: "retrieve handler",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
})
},
resp: `{
"id":"test-model",
"object":"model",
"created":1686935002,
"owned_by":"library"}
`,
},
{
name: "retrieve handler error forwarding",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
},
resp: `{
"error": {
"code": null,
"message": "model not found",
"param": null,
"type": "api_error"
}
}`,
},
}
gin.SetMode(gin.TestMode)
for _, tc := range testCases {
router := gin.New()
router.Use(RetrieveMiddleware())
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
}
}

View File

@ -1,21 +1,17 @@
// openai package provides middleware for partial compatibility with the OpenAI REST API
// openai package provides core transformation logic for partial compatibility with the OpenAI REST API
package openai
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
@ -220,11 +216,12 @@ func NewError(code int, message string) ErrorResponse {
return ErrorResponse{Error{Type: etype, Message: message}}
}
func toUsage(r api.ChatResponse) Usage {
// ToUsage converts an api.ChatResponse to Usage
func ToUsage(r api.ChatResponse) Usage {
return Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
PromptTokens: r.Metrics.PromptEvalCount,
CompletionTokens: r.Metrics.EvalCount,
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
}
}
@ -256,7 +253,8 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
return toolCalls
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
// ToChatCompletion converts an api.ChatResponse to ChatCompletion
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletion{
Id: id,
@ -276,12 +274,13 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}
return nil
}(r.DoneReason),
}}, Usage: toUsage(r),
}}, Usage: ToUsage(r),
DebugInfo: r.DebugInfo,
}
}
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
// ToChunk converts an api.ChatResponse to ChatCompletionChunk
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletionChunk{
Id: id,
@ -305,15 +304,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
}
}
func toUsageGenerate(r api.GenerateResponse) Usage {
// ToUsageGenerate converts an api.GenerateResponse to Usage
func ToUsageGenerate(r api.GenerateResponse) Usage {
return Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
PromptTokens: r.Metrics.PromptEvalCount,
CompletionTokens: r.Metrics.EvalCount,
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
}
}
func toCompletion(id string, r api.GenerateResponse) Completion {
// ToCompletion converts an api.GenerateResponse to Completion
func ToCompletion(id string, r api.GenerateResponse) Completion {
return Completion{
Id: id,
Object: "text_completion",
@ -330,11 +331,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
return nil
}(r.DoneReason),
}},
Usage: toUsageGenerate(r),
Usage: ToUsageGenerate(r),
}
}
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
return CompletionChunk{
Id: id,
Object: "text_completion",
@ -354,7 +356,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
}
}
func toListCompletion(r api.ListResponse) ListCompletion {
// ToListCompletion converts an api.ListResponse to ListCompletion
func ToListCompletion(r api.ListResponse) ListCompletion {
var data []Model
for _, m := range r.Models {
data = append(data, Model{
@ -371,7 +374,8 @@ func toListCompletion(r api.ListResponse) ListCompletion {
}
}
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
if r.Embeddings != nil {
var data []Embedding
for i, e := range r.Embeddings {
@ -396,7 +400,8 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
return EmbeddingList{}
}
func toModel(r api.ShowResponse, m string) Model {
// ToModel converts an api.ShowResponse to Model
func ToModel(r api.ShowResponse, m string) Model {
return Model{
Id: m,
Object: "model",
@ -405,7 +410,8 @@ func toModel(r api.ShowResponse, m string) Model {
}
}
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
var messages []api.Message
for _, msg := range r.Messages {
toolName := ""
@ -609,7 +615,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
return apiToolCalls, nil
}
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
options := make(map[string]any)
switch stop := r.Stop.(type) {
@ -660,413 +667,3 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
type BaseWriter struct {
gin.ResponseWriter
}
type ChatWriter struct {
stream bool
streamOptions *StreamOptions
id string
toolCallSent bool
BaseWriter
}
type CompleteWriter struct {
stream bool
streamOptions *StreamOptions
id string
BaseWriter
}
type ListWriter struct {
BaseWriter
}
type RetrieveWriter struct {
BaseWriter
model string
}
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
// chat chunk
if w.stream {
c := toChunk(w.id, chatResponse, w.toolCallSent)
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
w.toolCallSent = true
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if chatResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsage(chatResponse)
c.Usage = &u
c.Choices = []ChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// chat completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
err := json.Unmarshal(data, &generateResponse)
if err != nil {
return 0, err
}
// completion chunk
if w.stream {
c := toCompleteChunk(w.id, generateResponse)
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &Usage{}
}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if generateResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsageGenerate(generateResponse)
c.Usage = &u
c.Choices = []CompleteChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *ListWriter) writeResponse(data []byte) (int, error) {
var listResponse api.ListResponse
err := json.Unmarshal(data, &listResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
var showResponse api.ShowResponse
err := json.Unmarshal(data, &showResponse)
if err != nil {
return 0, err
}
// retrieve completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
func RetrieveMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
// response writer
w := &RetrieveWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: c.Param("model"),
}
c.Writer = w
c.Next()
}
}
func CompletionsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req CompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
genReq, err := fromCompleteRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &CompleteWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w
c.Next()
}
}
func ChatMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req ChatCompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
return
}
var b bytes.Buffer
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}

File diff suppressed because it is too large Load Diff

View File

@ -179,7 +179,7 @@ function buildROCm() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
rm -f $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906*
Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
}
}
}

View File

@ -37,8 +37,8 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/middleware"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@ -1449,11 +1449,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/embeddings", s.EmbeddingsHandler)
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
if rc != nil {
// wrap old with new