2023-07-04 03:22:44 +08:00
package server
import (
2023-07-26 05:08:51 +08:00
"context"
2023-07-07 01:40:11 +08:00
"encoding/json"
2023-10-07 04:06:20 +08:00
"errors"
2023-07-22 14:02:12 +08:00
"fmt"
2023-07-04 03:22:44 +08:00
"io"
2023-10-07 04:06:20 +08:00
"io/fs"
2024-01-19 02:52:01 +08:00
"log/slog"
2023-07-04 03:22:44 +08:00
"net"
"net/http"
2023-07-08 03:27:43 +08:00
"os"
2023-08-31 04:35:03 +08:00
"os/signal"
2023-07-15 08:27:14 +08:00
"path/filepath"
2023-08-01 09:35:18 +08:00
"reflect"
2023-09-12 23:04:35 +08:00
"runtime"
2023-07-07 01:40:11 +08:00
"strings"
2023-07-19 02:59:42 +08:00
"sync"
2023-08-31 04:35:03 +08:00
"syscall"
2023-07-13 09:18:06 +08:00
"time"
2023-07-04 03:22:44 +08:00
2023-07-22 09:01:24 +08:00
"github.com/gin-contrib/cors"
2023-07-04 03:22:44 +08:00
"github.com/gin-gonic/gin"
2024-02-13 03:16:20 +08:00
"golang.org/x/exp/slices"
2023-07-04 03:22:44 +08:00
2023-07-04 04:32:48 +08:00
"github.com/jmorganca/ollama/api"
2023-11-30 03:00:37 +08:00
"github.com/jmorganca/ollama/gpu"
2023-07-22 04:33:56 +08:00
"github.com/jmorganca/ollama/llm"
2024-02-08 06:24:29 +08:00
"github.com/jmorganca/ollama/openai"
2023-11-15 04:30:34 +08:00
"github.com/jmorganca/ollama/parser"
2023-10-14 07:08:35 +08:00
"github.com/jmorganca/ollama/version"
2023-07-04 03:22:44 +08:00
)
2023-08-23 00:48:35 +08:00
// mode is the gin run mode applied by init; values other than the three gin
// modes are replaced with gin.DebugMode there.
var mode = gin.DebugMode
2023-12-15 08:47:40 +08:00
// Server carries configuration for the HTTP API; GenerateRoutes is defined on it.
type Server struct {
	// WorkDir is a scratch directory created by NewServer via os.MkdirTemp.
	WorkDir string
}
2023-08-23 00:48:35 +08:00
// init validates mode and applies it to gin. Anything other than
// gin.DebugMode, gin.ReleaseMode, or gin.TestMode falls back to gin.DebugMode.
func init() {
	recognized := mode == gin.DebugMode || mode == gin.ReleaseMode || mode == gin.TestMode
	if !recognized {
		mode = gin.DebugMode
	}
	gin.SetMode(mode)
}
2023-08-01 09:35:18 +08:00
var loaded struct {
2023-07-20 06:00:28 +08:00
mu sync . Mutex
2023-10-19 22:39:58 +08:00
runner llm . LLM
2023-07-20 06:00:28 +08:00
expireAt time . Time
expireTimer * time . Timer
2023-08-01 09:35:18 +08:00
2023-10-19 22:39:58 +08:00
* Model
* api . Options
2023-07-19 02:59:42 +08:00
}
2023-08-15 21:35:39 +08:00
// defaultSessionDuration is how long a model stays loaded after use when the
// request leaves KeepAlive unset.
var defaultSessionDuration = time.Minute * 5
2023-08-09 03:13:22 +08:00
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
2024-01-04 01:01:42 +08:00
func load ( c * gin . Context , model * Model , opts api . Options , sessionDuration time . Duration ) error {
2023-12-06 03:57:33 +08:00
workDir := c . GetString ( "workDir" )
2023-10-19 22:39:58 +08:00
needLoad := loaded . runner == nil || // is there a model loaded?
loaded . ModelPath != model . ModelPath || // has the base model changed?
! reflect . DeepEqual ( loaded . AdapterPaths , model . AdapterPaths ) || // have the adapters changed?
! reflect . DeepEqual ( loaded . Options . Runner , opts . Runner ) // have the runner options changed?
if needLoad {
if loaded . runner != nil {
2024-01-19 02:52:01 +08:00
slog . Info ( "changing loaded model" )
2023-10-19 22:39:58 +08:00
loaded . runner . Close ( )
loaded . runner = nil
loaded . Model = nil
loaded . Options = nil
2023-07-19 02:59:42 +08:00
}
2023-07-18 03:08:10 +08:00
2023-12-01 02:30:23 +08:00
llmRunner , err := llm . New ( workDir , model . ModelPath , model . AdapterPaths , model . ProjectorPaths , opts )
2023-07-19 02:59:42 +08:00
if err != nil {
2023-10-20 02:50:45 +08:00
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
2023-11-25 02:58:09 +08:00
if errors . Is ( llm . ErrUnsupportedFormat , err ) || strings . Contains ( err . Error ( ) , "failed to load model" ) {
2023-10-20 02:50:45 +08:00
err = fmt . Errorf ( "%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`" , err , model . ShortName )
}
2024-01-04 01:01:42 +08:00
return err
2023-07-19 02:59:42 +08:00
}
2023-10-19 22:39:58 +08:00
loaded . Model = model
loaded . runner = llmRunner
loaded . Options = & opts
2023-07-20 06:00:28 +08:00
}
2023-09-22 03:38:49 +08:00
2023-08-01 09:35:18 +08:00
loaded . expireAt = time . Now ( ) . Add ( sessionDuration )
2023-08-09 03:13:22 +08:00
2023-08-01 09:35:18 +08:00
if loaded . expireTimer == nil {
loaded . expireTimer = time . AfterFunc ( sessionDuration , func ( ) {
loaded . mu . Lock ( )
defer loaded . mu . Unlock ( )
2023-07-20 06:00:28 +08:00
2023-08-01 09:35:18 +08:00
if time . Now ( ) . Before ( loaded . expireAt ) {
2023-07-20 06:00:28 +08:00
return
}
2023-10-19 22:39:58 +08:00
if loaded . runner != nil {
loaded . runner . Close ( )
2023-07-20 06:00:28 +08:00
}
2023-10-19 22:39:58 +08:00
loaded . runner = nil
loaded . Model = nil
loaded . Options = nil
2023-07-20 06:00:28 +08:00
} )
2023-07-07 01:40:11 +08:00
}
2023-09-22 03:38:49 +08:00
2023-08-01 09:35:18 +08:00
loaded . expireTimer . Reset ( sessionDuration )
2024-01-04 01:01:42 +08:00
return nil
}
func modelOptions ( model * Model , requestOpts map [ string ] interface { } ) ( api . Options , error ) {
opts := api . DefaultOptions ( )
if err := opts . FromMap ( model . Options ) ; err != nil {
return api . Options { } , err
}
if err := opts . FromMap ( requestOpts ) ; err != nil {
return api . Options { } , err
}
return opts , nil
2023-08-09 03:13:22 +08:00
}
2024-02-13 03:16:20 +08:00
// isSupportedImageType reports whether image sniffs (via http.DetectContentType)
// as JPEG or PNG.
func isSupportedImageType(image []byte) bool {
	detected := http.DetectContentType(image)
	return slices.Index([]string{"image/jpeg", "image/jpg", "image/png"}, detected) != -1
}
2023-08-09 03:13:22 +08:00
// GenerateHandler serves a single-prompt completion request.
//
// It holds loaded.mu for the whole request. It validates the
// api.GenerateRequest, loads the model, and builds the prompt from the model
// template unless req.Raw is set. Prediction runs in a goroutine that feeds
// ch. A request with an empty prompt, template, and system only loads the model
// and returns a Done response. When req.Stream is explicitly false the chunks
// are accumulated into one response; otherwise they are streamed.
func GenerateHandler(c *gin.Context) {
	loaded.mu.Lock()
	defer loaded.mu.Unlock()

	checkpointStart := time.Now()
	var req api.GenerateRequest
	err := c.ShouldBindJSON(&req)

	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// validate the request
	switch {
	case req.Model == "":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	case len(req.Format) > 0 && req.Format != "json":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
		return
	case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
		return
	}

	for _, img := range req.Images {
		if !isSupportedImageType(img) {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
			return
		}
	}

	model, err := GetModel(req.Model)
	if err != nil {
		var pErr *fs.PathError
		if errors.As(err, &pErr) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	opts, err := modelOptions(model, req.Options)
	if err != nil {
		if errors.Is(err, api.ErrInvalidOpts) {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	var sessionDuration time.Duration
	if req.KeepAlive == nil {
		sessionDuration = defaultSessionDuration
	} else {
		sessionDuration = req.KeepAlive.Duration
	}

	if err := load(c, model, opts, sessionDuration); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// an empty request loads the model
	if req.Prompt == "" && req.Template == "" && req.System == "" {
		c.JSON(http.StatusOK, api.GenerateResponse{
			CreatedAt: time.Now().UTC(),
			Model:     req.Model,
			Done:      true,
		})
		return
	}

	checkpointLoaded := time.Now()

	var prompt string
	var promptVars PromptVars
	switch {
	case req.Raw:
		prompt = req.Prompt
	case req.Prompt != "":
		if req.Template != "" {
			// override the default model template
			model.Template = req.Template
		}

		var rebuild strings.Builder
		if req.Context != nil {
			// TODO: context is deprecated, at some point the context logic within this conditional should be removed
			prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}

			// Remove leading spaces from prevCtx if present
			prevCtx = strings.TrimPrefix(prevCtx, " ")
			rebuild.WriteString(prevCtx)
		}

		promptVars = PromptVars{
			System: req.System,
			Prompt: req.Prompt,
			First:  len(req.Context) == 0,
		}

		if promptVars.System == "" {
			promptVars.System = model.System
		}

		for i := range req.Images {
			promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
		}

		p, err := model.PreResponsePrompt(promptVars)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		rebuild.WriteString(p)
		prompt = rebuild.String()
	}

	slog.Debug("generate handler", "prompt", prompt)

	ch := make(chan any)
	var generated strings.Builder
	go func() {
		defer close(ch)

		fn := func(r llm.PredictResult) {
			// Update model expiration
			loaded.expireAt = time.Now().Add(sessionDuration)
			loaded.expireTimer.Reset(sessionDuration)

			// Build up the full response
			if _, err := generated.WriteString(r.Content); err != nil {
				ch <- gin.H{"error": err.Error()}
				return
			}

			resp := api.GenerateResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Done:      r.Done,
				Response:  r.Content,
				Metrics: api.Metrics{
					PromptEvalCount:    r.PromptEvalCount,
					PromptEvalDuration: r.PromptEvalDuration,
					EvalCount:          r.EvalCount,
					EvalDuration:       r.EvalDuration,
				},
			}

			if r.Done {
				resp.TotalDuration = time.Since(checkpointStart)
				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)

				if !req.Raw {
					// append the generated text to the history and template it if needed
					promptVars.Response = generated.String()
					result, err := model.PostResponseTemplate(promptVars)
					if err != nil {
						ch <- gin.H{"error": err.Error()}
						return
					}
					embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
					if err != nil {
						ch <- gin.H{"error": err.Error()}
						return
					}
					resp.Context = embd
				}
			}

			ch <- resp
		}

		var images []llm.ImageData
		for i := range req.Images {
			images = append(images, llm.ImageData{
				ID:   i,
				Data: req.Images[i],
			})
		}

		// Start prediction
		predictReq := llm.PredictOpts{
			Prompt:  prompt,
			Format:  req.Format,
			Images:  images,
			Options: opts,
		}
		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	// Drain ch before returning so the producer goroutine always finishes while
	// loaded.mu is still held. Deferred calls run LIFO, so this runs before the
	// Unlock deferred at the top. Without it, an early return below, or
	// streamResponse returning before ch is closed, would leave the goroutine
	// blocked on the unbuffered send. It would also still be touching
	// loaded.runner and loaded.expireTimer after the lock was released.
	defer func() {
		for range ch {
		}
	}()

	if req.Stream != nil && !*req.Stream {
		// Accumulate responses into the final response
		var final api.GenerateResponse
		var sb strings.Builder
		for resp := range ch {
			switch r := resp.(type) {
			case api.GenerateResponse:
				sb.WriteString(r.Response)
				final = r
			case gin.H:
				errorMsg, ok := r["error"].(string)
				if !ok {
					errorMsg = "unexpected error format in response"
				}
				c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
				return
			default:
				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
				return
			}
		}

		final.Response = sb.String()
		c.JSON(http.StatusOK, final)
		return
	}

	streamResponse(c, ch)
}
// EmbeddingHandler computes an embedding for req.Prompt.
//
// It holds loaded.mu for the whole request and loads the requested model. It
// rejects the request with 400 unless the loaded options have EmbeddingOnly
// set.
func EmbeddingHandler(c *gin.Context) {
	loaded.mu.Lock()
	defer loaded.mu.Unlock()

	var req api.EmbeddingRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		msg := bindErr.Error()
		if errors.Is(bindErr, io.EOF) {
			msg = "missing request body"
		}
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
		return
	}

	if req.Model == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	model, err := GetModel(req.Model)
	if err != nil {
		var pathErr *fs.PathError
		if errors.As(err, &pathErr) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	opts, err := modelOptions(model, req.Options)
	if err != nil {
		status := http.StatusInternalServerError
		if errors.Is(err, api.ErrInvalidOpts) {
			status = http.StatusBadRequest
		}
		c.JSON(status, gin.H{"error": err.Error()})
		return
	}

	sessionDuration := defaultSessionDuration
	if req.KeepAlive != nil {
		sessionDuration = req.KeepAlive.Duration
	}

	if err := load(c, model, opts, sessionDuration); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if !loaded.Options.EmbeddingOnly {
		c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
		return
	}

	embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
	if err != nil {
		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
		return
	}

	c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
}
2023-07-21 07:09:23 +08:00
// PullModelHandler downloads a model from a registry, reporting progress.
//
// The model name comes from req.Model, falling back to req.Name. Progress and
// errors are sent on a channel. They are streamed unless req.Stream is
// explicitly false, in which case waitForStream collects them.
func PullModelHandler(c *gin.Context) {
	var req api.PullRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		msg := bindErr.Error()
		if errors.Is(bindErr, io.EOF) {
			msg = "missing request body"
		}
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
		return
	}

	name := req.Model
	if name == "" {
		name = req.Name
	}
	if name == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	progress := make(chan any)
	go func() {
		defer close(progress)

		ctx, cancel := context.WithCancel(c.Request.Context())
		defer cancel()

		registry := &RegistryOptions{Insecure: req.Insecure}
		report := func(r api.ProgressResponse) { progress <- r }

		if err := PullModel(ctx, name, registry, report); err != nil {
			progress <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream == nil || *req.Stream {
		streamResponse(c, progress)
		return
	}
	waitForStream(c, progress)
}
2023-07-21 07:09:23 +08:00
// PushModelHandler uploads a local model to a registry, reporting progress.
//
// The model name comes from req.Model, falling back to req.Name. Progress and
// errors are sent on a channel. They are streamed unless req.Stream is
// explicitly false, in which case waitForStream collects them.
func PushModelHandler(c *gin.Context) {
	var req api.PushRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		msg := bindErr.Error()
		if errors.Is(bindErr, io.EOF) {
			msg = "missing request body"
		}
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
		return
	}

	name := req.Model
	if name == "" {
		name = req.Name
	}
	if name == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	progress := make(chan any)
	go func() {
		defer close(progress)

		ctx, cancel := context.WithCancel(c.Request.Context())
		defer cancel()

		registry := &RegistryOptions{Insecure: req.Insecure}
		report := func(r api.ProgressResponse) { progress <- r }

		if err := PushModel(ctx, name, registry, report); err != nil {
			progress <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream == nil || *req.Stream {
		streamResponse(c, progress)
		return
	}
	waitForStream(c, progress)
}
2023-07-21 07:09:23 +08:00
// CreateModelHandler builds a model from a Modelfile.
//
// The Modelfile text comes from req.Modelfile or, when that is empty, from the
// file at req.Path. The model name (req.Model, falling back to req.Name) must
// be a valid model path. Progress and errors are streamed unless req.Stream is
// explicitly false.
func CreateModelHandler(c *gin.Context) {
	var req api.CreateRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		msg := bindErr.Error()
		if errors.Is(bindErr, io.EOF) {
			msg = "missing request body"
		}
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
		return
	}

	name := req.Model
	if name == "" {
		name = req.Name
	}
	if name == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	if err := ParseModelPath(name).Validate(); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if req.Path == "" && req.Modelfile == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
		return
	}

	// Inline Modelfile text wins; the file at req.Path is only read when no
	// inline text was supplied. The file stays open until the handler returns.
	var source io.Reader
	if req.Path != "" && req.Modelfile == "" {
		f, err := os.Open(req.Path)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
			return
		}
		defer f.Close()
		source = f
	} else {
		source = strings.NewReader(req.Modelfile)
	}

	commands, err := parser.Parse(source)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	progress := make(chan any)
	go func() {
		defer close(progress)

		ctx, cancel := context.WithCancel(c.Request.Context())
		defer cancel()

		report := func(resp api.ProgressResponse) { progress <- resp }
		if err := CreateModel(ctx, name, filepath.Dir(req.Path), commands, report); err != nil {
			progress <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream == nil || *req.Stream {
		streamResponse(c, progress)
		return
	}
	waitForStream(c, progress)
}
2023-07-21 07:09:23 +08:00
// DeleteModelHandler removes a model (req.Model, falling back to req.Name).
// It then prunes the manifests directory via PruneDirectory. A model that does
// not exist yields 404; other failures yield 500.
func DeleteModelHandler(c *gin.Context) {
	var req api.DeleteRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		msg := bindErr.Error()
		if errors.Is(bindErr, io.EOF) {
			msg = "missing request body"
		}
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
		return
	}

	name := req.Model
	if name == "" {
		name = req.Name
	}
	if name == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	if err := DeleteModel(name); err != nil {
		status, msg := http.StatusInternalServerError, err.Error()
		if os.IsNotExist(err) {
			status, msg = http.StatusNotFound, fmt.Sprintf("model '%s' not found", name)
		}
		c.JSON(status, gin.H{"error": msg})
		return
	}

	manifestsPath, err := GetManifestPath()
	if err == nil {
		err = PruneDirectory(manifestsPath)
	}
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, nil)
}
2023-09-07 02:04:17 +08:00
// ShowModelHandler returns details for a model (req.Model, falling back to
// req.Name) as built by GetModelInfo. A missing model yields 404.
func ShowModelHandler(c *gin.Context) {
	var req api.ShowRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		msg := bindErr.Error()
		if errors.Is(bindErr, io.EOF) {
			msg = "missing request body"
		}
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": msg})
		return
	}

	if req.Model == "" {
		if req.Name == "" {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
			return
		}
		req.Model = req.Name
	}

	resp, err := GetModelInfo(req)
	if err != nil {
		status, msg := http.StatusInternalServerError, err.Error()
		if os.IsNotExist(err) {
			status, msg = http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model)
		}
		c.JSON(status, gin.H{"error": msg})
		return
	}

	c.JSON(http.StatusOK, resp)
}
2024-01-05 09:23:11 +08:00
// GetModelInfo builds the api.ShowResponse for req.Model.
//
// Non-empty req.System and req.Template replace the model's values in the
// response and in the rendered Modelfile. Parameters lists the model's stored
// options as one "key value" line per value (slice values expand to one line
// per element), ordered by key so the output is deterministic. req.Options are
// merged into the model's options afterwards, so they appear in the Modelfile
// but not in Parameters.
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
	model, err := GetModel(req.Model)
	if err != nil {
		return nil, err
	}

	modelDetails := api.ModelDetails{
		ParentModel:       model.ParentModel,
		Format:            model.Config.ModelFormat,
		Family:            model.Config.ModelFamily,
		Families:          model.Config.ModelFamilies,
		ParameterSize:     model.Config.ModelType,
		QuantizationLevel: model.Config.FileType,
	}

	if req.System != "" {
		model.System = req.System
	}

	if req.Template != "" {
		model.Template = req.Template
	}

	// kept non-nil, presumably so an empty list serializes as [] rather than null
	msgs := make([]api.Message, 0, len(model.Messages))
	for _, msg := range model.Messages {
		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
	}

	resp := &api.ShowResponse{
		License:  strings.Join(model.License, "\n"),
		System:   model.System,
		Template: model.Template,
		Details:  modelDetails,
		Messages: msgs,
	}

	// map iteration order is random; sort the keys so Parameters is stable
	keys := make([]string, 0, len(model.Options))
	for k := range model.Options {
		keys = append(keys, k)
	}
	slices.Sort(keys)

	const cs = 30 // width of the key column
	var params []string
	for _, k := range keys {
		switch val := model.Options[k].(type) {
		case []interface{}:
			for _, nv := range val {
				params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
			}
		default:
			params = append(params, fmt.Sprintf("%-*s %#v", cs, k, val))
		}
	}
	resp.Parameters = strings.Join(params, "\n")

	// model.Options may be nil (NOTE(review): e.g. when the model stores no
	// parameters; depends on GetModel). Assigning into a nil map panics, so
	// allocate it before merging the request's overrides.
	if len(req.Options) > 0 && model.Options == nil {
		model.Options = make(map[string]interface{}, len(req.Options))
	}
	for k, v := range req.Options {
		model.Options[k] = v
	}

	mf, err := ShowModelfile(model)
	if err != nil {
		return nil, err
	}

	resp.Modelfile = mf

	return resp, nil
}
2023-07-21 07:09:23 +08:00
func ListModelsHandler ( c * gin . Context ) {
2023-10-18 01:02:43 +08:00
models := make ( [ ] api . ModelResponse , 0 )
2023-12-16 07:50:51 +08:00
manifestsPath , err := GetManifestPath ( )
2023-07-19 00:09:45 +08:00
if err != nil {
c . JSON ( http . StatusInternalServerError , gin . H { "error" : err . Error ( ) } )
return
}
2023-08-31 02:14:12 +08:00
2023-12-12 05:56:22 +08:00
modelResponse := func ( modelName string ) ( api . ModelResponse , error ) {
model , err := GetModel ( modelName )
if err != nil {
return api . ModelResponse { } , err
}
modelDetails := api . ModelDetails {
Format : model . Config . ModelFormat ,
Family : model . Config . ModelFamily ,
Families : model . Config . ModelFamilies ,
ParameterSize : model . Config . ModelType ,
QuantizationLevel : model . Config . FileType ,
}
return api . ModelResponse {
2024-01-19 06:32:55 +08:00
Model : model . ShortName ,
2023-12-12 05:56:22 +08:00
Name : model . ShortName ,
Size : model . Size ,
Digest : model . Digest ,
Details : modelDetails ,
} , nil
}
2023-08-31 02:14:12 +08:00
walkFunc := func ( path string , info os . FileInfo , _ error ) error {
2023-07-19 00:09:45 +08:00
if ! info . IsDir ( ) {
2023-12-16 07:50:51 +08:00
path , tag := filepath . Split ( path )
model := strings . Trim ( strings . TrimPrefix ( path , manifestsPath ) , string ( os . PathSeparator ) )
modelPath := strings . Join ( [ ] string { model , tag } , ":" )
canonicalModelPath := strings . ReplaceAll ( modelPath , string ( os . PathSeparator ) , "/" )
2023-08-22 12:56:56 +08:00
2023-12-16 07:50:51 +08:00
resp , err := modelResponse ( canonicalModelPath )
2023-07-19 00:09:45 +08:00
if err != nil {
2024-01-19 02:52:01 +08:00
slog . Info ( fmt . Sprintf ( "skipping file: %s" , canonicalModelPath ) )
2023-12-16 06:07:34 +08:00
// nolint: nilerr
2023-07-19 03:39:08 +08:00
return nil
2023-07-19 00:09:45 +08:00
}
2023-08-31 02:14:12 +08:00
2023-12-12 05:56:22 +08:00
resp . ModifiedAt = info . ModTime ( )
models = append ( models , resp )
2023-07-19 00:09:45 +08:00
}
2023-08-31 02:14:12 +08:00
2023-07-19 00:09:45 +08:00
return nil
2023-08-31 02:14:12 +08:00
}
2023-12-16 07:50:51 +08:00
if err := filepath . Walk ( manifestsPath , walkFunc ) ; err != nil {
2023-07-19 00:09:45 +08:00
c . JSON ( http . StatusInternalServerError , gin . H { "error" : err . Error ( ) } )
return
}
2023-07-20 06:00:28 +08:00
c . JSON ( http . StatusOK , api . ListResponse { Models : models } )
2023-07-19 00:09:45 +08:00
}
2023-07-24 23:27:28 +08:00
// CopyModelHandler copies the model named req.Source to req.Destination.
//
// Both names are required and the destination must be a valid model path. A
// missing source yields 404; other copy failures yield 500. On success the
// handler writes no body itself.
func CopyModelHandler(c *gin.Context) {
	var req api.CopyRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if req.Source == "" || req.Destination == "" {
		// fixed typo: the message previously read "source add destination"
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source and destination are required"})
		return
	}

	if err := ParseModelPath(req.Destination).Validate(); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if err := CopyModel(req.Source, req.Destination); err != nil {
		if os.IsNotExist(err) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}
}
2023-11-16 02:59:38 +08:00
func HeadBlobHandler ( c * gin . Context ) {
2023-11-15 06:07:40 +08:00
path , err := GetBlobsPath ( c . Param ( "digest" ) )
if err != nil {
c . AbortWithStatusJSON ( http . StatusBadRequest , gin . H { "error" : err . Error ( ) } )
return
}
if _ , err := os . Stat ( path ) ; err != nil {
c . AbortWithStatusJSON ( http . StatusNotFound , gin . H { "error" : fmt . Sprintf ( "blob %q not found" , c . Param ( "digest" ) ) } )
return
}
2023-11-16 05:55:37 +08:00
c . Status ( http . StatusOK )
2023-11-15 06:07:40 +08:00
}
func CreateBlobHandler ( c * gin . Context ) {
2023-11-25 04:01:23 +08:00
layer , err := NewLayer ( c . Request . Body , "" )
2023-11-18 07:21:57 +08:00
if err != nil {
c . AbortWithStatusJSON ( http . StatusInternalServerError , gin . H { "error" : err . Error ( ) } )
return
}
2023-11-25 04:01:23 +08:00
if layer . Digest != c . Param ( "digest" ) {
c . AbortWithStatusJSON ( http . StatusBadRequest , gin . H { "error" : fmt . Sprintf ( "digest mismatch, expected %q, got %q" , c . Param ( "digest" ) , layer . Digest ) } )
2023-11-15 06:07:40 +08:00
return
}
2023-11-25 04:01:23 +08:00
if _ , err := layer . Commit ( ) ; err != nil {
2023-11-15 06:07:40 +08:00
c . AbortWithStatusJSON ( http . StatusInternalServerError , gin . H { "error" : err . Error ( ) } )
return
}
2023-11-16 05:55:37 +08:00
c . Status ( http . StatusCreated )
2023-11-15 06:07:40 +08:00
}
2023-09-22 00:42:16 +08:00
// defaultAllowOrigins lists hosts whose browser origins the CORS middleware
// always accepts, in addition to any supplied through OLLAMA_ORIGINS.
// GenerateRoutes expands each host into http:// and https:// origins, both
// bare and with a wildcard port.
var defaultAllowOrigins = []string{"localhost", "127.0.0.1", "0.0.0.0"}
2023-12-15 08:47:40 +08:00
func NewServer ( ) ( * Server , error ) {
workDir , err := os . MkdirTemp ( "" , "ollama" )
if err != nil {
return nil , err
}
2023-10-30 23:10:18 +08:00
2023-12-15 08:47:40 +08:00
return & Server {
WorkDir : workDir ,
} , nil
}
2023-10-30 23:10:18 +08:00
2023-12-15 08:47:40 +08:00
func ( s * Server ) GenerateRoutes ( ) http . Handler {
var origins [ ] string
if o := os . Getenv ( "OLLAMA_ORIGINS" ) ; o != "" {
origins = strings . Split ( o , "," )
2023-10-30 23:10:18 +08:00
}
2023-07-22 09:01:24 +08:00
config := cors . DefaultConfig ( )
config . AllowWildcard = true
2024-01-05 09:55:47 +08:00
config . AllowBrowserExtensions = true
2023-09-22 00:42:16 +08:00
2023-12-15 08:47:40 +08:00
config . AllowOrigins = origins
2023-09-22 00:42:16 +08:00
for _ , allowOrigin := range defaultAllowOrigins {
config . AllowOrigins = append ( config . AllowOrigins ,
fmt . Sprintf ( "http://%s" , allowOrigin ) ,
fmt . Sprintf ( "https://%s" , allowOrigin ) ,
fmt . Sprintf ( "http://%s:*" , allowOrigin ) ,
fmt . Sprintf ( "https://%s:*" , allowOrigin ) ,
)
}
2023-07-22 09:01:24 +08:00
2023-07-06 03:37:33 +08:00
r := gin . Default ( )
2023-09-22 03:38:49 +08:00
r . Use (
cors . New ( config ) ,
func ( c * gin . Context ) {
2023-12-15 08:47:40 +08:00
c . Set ( "workDir" , s . WorkDir )
2023-09-22 03:38:49 +08:00
c . Next ( )
} ,
)
2023-07-06 03:37:33 +08:00
2023-07-21 07:09:23 +08:00
r . POST ( "/api/pull" , PullModelHandler )
r . POST ( "/api/generate" , GenerateHandler )
2023-12-06 03:57:33 +08:00
r . POST ( "/api/chat" , ChatHandler )
2023-08-09 03:13:22 +08:00
r . POST ( "/api/embeddings" , EmbeddingHandler )
2023-07-21 07:09:23 +08:00
r . POST ( "/api/create" , CreateModelHandler )
r . POST ( "/api/push" , PushModelHandler )
2023-07-24 23:27:28 +08:00
r . POST ( "/api/copy" , CopyModelHandler )
2023-07-21 07:09:23 +08:00
r . DELETE ( "/api/delete" , DeleteModelHandler )
2023-09-07 02:04:17 +08:00
r . POST ( "/api/show" , ShowModelHandler )
2023-11-15 06:07:40 +08:00
r . POST ( "/api/blobs/:digest" , CreateBlobHandler )
2023-11-16 07:22:12 +08:00
r . HEAD ( "/api/blobs/:digest" , HeadBlobHandler )
2023-07-04 03:22:44 +08:00
2024-02-08 06:24:29 +08:00
// Compatibility endpoints
r . POST ( "/v1/chat/completions" , openai . Middleware ( ) , ChatHandler )
2023-09-22 07:38:03 +08:00
for _ , method := range [ ] string { http . MethodGet , http . MethodHead } {
r . Handle ( method , "/" , func ( c * gin . Context ) {
c . String ( http . StatusOK , "Ollama is running" )
} )
r . Handle ( method , "/api/tags" , ListModelsHandler )
2023-10-13 06:45:07 +08:00
r . Handle ( method , "/api/version" , func ( c * gin . Context ) {
c . JSON ( http . StatusOK , gin . H { "version" : version . Version } )
} )
2023-09-22 07:38:03 +08:00
}
2023-12-15 08:47:40 +08:00
return r
}
func Serve ( ln net . Listener ) error {
2024-02-01 06:59:32 +08:00
level := slog . LevelInfo
2024-01-19 02:52:01 +08:00
if debug := os . Getenv ( "OLLAMA_DEBUG" ) ; debug != "" {
2024-02-01 06:59:32 +08:00
level = slog . LevelDebug
2024-01-19 02:52:01 +08:00
}
2024-02-01 06:59:32 +08:00
handler := slog . NewTextHandler ( os . Stderr , & slog . HandlerOptions {
Level : level ,
AddSource : true ,
ReplaceAttr : func ( _ [ ] string , attr slog . Attr ) slog . Attr {
if attr . Key == slog . SourceKey {
source := attr . Value . Any ( ) . ( * slog . Source )
source . File = filepath . Base ( source . File )
}
return attr
} ,
} )
slog . SetDefault ( slog . New ( handler ) )
2023-12-15 08:47:40 +08:00
if noprune := os . Getenv ( "OLLAMA_NOPRUNE" ) ; noprune == "" {
// clean up unused layers and manifests
if err := PruneLayers ( ) ; err != nil {
return err
}
manifestsPath , err := GetManifestPath ( )
if err != nil {
return err
}
if err := PruneDirectory ( manifestsPath ) ; err != nil {
return err
}
}
s , err := NewServer ( )
if err != nil {
return err
}
r := s . GenerateRoutes ( )
2024-01-19 02:52:01 +08:00
slog . Info ( fmt . Sprintf ( "Listening on %s (version %s)" , ln . Addr ( ) , version . Version ) )
2023-12-15 08:47:40 +08:00
srvr := & http . Server {
2023-07-04 03:22:44 +08:00
Handler : r ,
}
2023-08-31 04:35:03 +08:00
// listen for a ctrl+c and stop any loaded llm
signals := make ( chan os . Signal , 1 )
2023-09-22 03:38:49 +08:00
signal . Notify ( signals , syscall . SIGINT , syscall . SIGTERM )
2023-08-31 04:35:03 +08:00
go func ( ) {
<- signals
2023-10-19 22:39:58 +08:00
if loaded . runner != nil {
loaded . runner . Close ( )
2023-09-23 02:41:52 +08:00
}
2023-12-15 08:47:40 +08:00
os . RemoveAll ( s . WorkDir )
2023-08-31 04:35:03 +08:00
os . Exit ( 0 )
} ( )
2023-11-30 03:00:37 +08:00
if err := llm . Init ( s . WorkDir ) ; err != nil {
return fmt . Errorf ( "unable to initialize llm library %w" , err )
}
if runtime . GOOS == "linux" { // TODO - windows too
2023-09-12 23:04:35 +08:00
// check compatibility to log warnings
2023-11-30 03:00:37 +08:00
if _ , err := gpu . CheckVRAM ( ) ; err != nil {
2024-01-19 02:52:01 +08:00
slog . Info ( err . Error ( ) )
2023-09-12 23:04:35 +08:00
}
}
2023-12-15 08:47:40 +08:00
return srvr . Serve ( ln )
2023-07-04 03:22:44 +08:00
}
2023-07-07 01:40:11 +08:00
2023-10-12 00:54:27 +08:00
// waitForStream consumes progress updates from ch and writes exactly one JSON
// response to the client instead of streaming them:
//   - the first api.ProgressResponse with Status "success" is written with 200;
//   - any other api.ProgressResponse is skipped;
//   - a gin.H is an error report and yields 500, echoing its "error" string
//     when that value is a string;
//   - any other value, or ch closing before a success arrives, yields 500.
func waitForStream(c *gin.Context, ch chan interface{}) {
	c.Header("Content-Type", "application/json")

	for item := range ch {
		switch msg := item.(type) {
		case api.ProgressResponse:
			if msg.Status != "success" {
				continue
			}
			c.JSON(http.StatusOK, msg)
		case gin.H:
			errorMsg, ok := msg["error"].(string)
			if !ok {
				errorMsg = "unexpected error format in progress response"
			}
			c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
		}
		// every non-skipped value above is terminal
		return
	}

	c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
}
2023-07-15 05:15:53 +08:00
func streamResponse ( c * gin . Context , ch chan any ) {
2023-08-09 12:38:10 +08:00
c . Header ( "Content-Type" , "application/x-ndjson" )
2023-07-12 02:54:22 +08:00
c . Stream ( func ( w io . Writer ) bool {
val , ok := <- ch
if ! ok {
return false
}
bts , err := json . Marshal ( val )
if err != nil {
2024-01-19 02:52:01 +08:00
slog . Info ( fmt . Sprintf ( "streamResponse: json.Marshal failed with %s" , err ) )
2023-07-12 02:54:22 +08:00
return false
}
2023-09-30 12:45:52 +08:00
// Delineate chunks with new-line delimiter
2023-07-12 02:54:22 +08:00
bts = append ( bts , '\n' )
if _ , err := w . Write ( bts ) ; err != nil {
2024-01-19 02:52:01 +08:00
slog . Info ( fmt . Sprintf ( "streamResponse: w.Write failed with %s" , err ) )
2023-07-12 02:54:22 +08:00
return false
}
return true
} )
}
2023-12-06 03:57:33 +08:00
func ChatHandler ( c * gin . Context ) {
loaded . mu . Lock ( )
defer loaded . mu . Unlock ( )
checkpointStart := time . Now ( )
var req api . ChatRequest
err := c . ShouldBindJSON ( & req )
switch {
case errors . Is ( err , io . EOF ) :
c . AbortWithStatusJSON ( http . StatusBadRequest , gin . H { "error" : "missing request body" } )
return
case err != nil :
c . AbortWithStatusJSON ( http . StatusBadRequest , gin . H { "error" : err . Error ( ) } )
return
}
// validate the request
switch {
case req . Model == "" :
c . AbortWithStatusJSON ( http . StatusBadRequest , gin . H { "error" : "model is required" } )
return
case len ( req . Format ) > 0 && req . Format != "json" :
c . AbortWithStatusJSON ( http . StatusBadRequest , gin . H { "error" : "format must be json" } )
return
}
2024-02-13 03:16:20 +08:00
for _ , msg := range req . Messages {
for _ , img := range msg . Images {
if ! isSupportedImageType ( img ) {
c . AbortWithStatusJSON ( http . StatusBadRequest , gin . H { "error" : "unsupported image format" } )
return
}
}
}
2024-01-04 01:01:42 +08:00
model , err := GetModel ( req . Model )
2023-12-06 03:57:33 +08:00
if err != nil {
var pErr * fs . PathError
2024-01-04 01:01:42 +08:00
if errors . As ( err , & pErr ) {
2023-12-06 03:57:33 +08:00
c . JSON ( http . StatusNotFound , gin . H { "error" : fmt . Sprintf ( "model '%s' not found, try pulling it first" , req . Model ) } )
2024-01-04 01:01:42 +08:00
return
}
c . JSON ( http . StatusInternalServerError , gin . H { "error" : err . Error ( ) } )
return
}
opts , err := modelOptions ( model , req . Options )
if err != nil {
if errors . Is ( err , api . ErrInvalidOpts ) {
2023-12-06 03:57:33 +08:00
c . JSON ( http . StatusBadRequest , gin . H { "error" : err . Error ( ) } )
2024-01-04 01:01:42 +08:00
return
2023-12-06 03:57:33 +08:00
}
2024-01-04 01:01:42 +08:00
c . JSON ( http . StatusInternalServerError , gin . H { "error" : err . Error ( ) } )
return
}
2024-01-27 06:28:02 +08:00
var sessionDuration time . Duration
if req . KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req . KeepAlive . Duration
}
2024-01-04 01:01:42 +08:00
if err := load ( c , model , opts , sessionDuration ) ; err != nil {
c . JSON ( http . StatusInternalServerError , gin . H { "error" : err . Error ( ) } )
2023-12-06 03:57:33 +08:00
return
}
checkpointLoaded := time . Now ( )
2024-01-31 04:59:29 +08:00
chat , err := model . ChatPrompts ( req . Messages )
2023-12-06 03:57:33 +08:00
if err != nil {
c . JSON ( http . StatusBadRequest , gin . H { "error" : err . Error ( ) } )
return
}
2024-02-01 09:39:38 +08:00
prompt , images , err := trimmedPrompt ( c . Request . Context ( ) , chat , model )
2024-01-31 04:59:29 +08:00
if err != nil {
c . JSON ( http . StatusInternalServerError , gin . H { "error" : err . Error ( ) } )
return
}
2023-12-06 03:57:33 +08:00
2024-02-08 08:30:33 +08:00
// an empty request loads the model
if len ( prompt ) == 0 {
resp := api . ChatResponse {
CreatedAt : time . Now ( ) . UTC ( ) ,
Model : req . Model ,
Done : true ,
Message : api . Message { Role : "assistant" } ,
}
c . JSON ( http . StatusOK , resp )
return
}
2024-02-01 08:47:26 +08:00
slog . Debug ( "chat handler" , "prompt" , prompt )
2024-01-29 07:22:35 +08:00
2023-12-06 03:57:33 +08:00
ch := make ( chan any )
go func ( ) {
defer close ( ch )
fn := func ( r llm . PredictResult ) {
// Update model expiration
loaded . expireAt = time . Now ( ) . Add ( sessionDuration )
loaded . expireTimer . Reset ( sessionDuration )
resp := api . ChatResponse {
2023-12-11 00:42:15 +08:00
Model : req . Model ,
2023-12-15 01:15:50 +08:00
CreatedAt : time . Now ( ) . UTC ( ) ,
2023-12-19 03:23:38 +08:00
Message : api . Message { Role : "assistant" , Content : r . Content } ,
2023-12-06 03:57:33 +08:00
Done : r . Done ,
Metrics : api . Metrics {
PromptEvalCount : r . PromptEvalCount ,
PromptEvalDuration : r . PromptEvalDuration ,
EvalCount : r . EvalCount ,
EvalDuration : r . EvalDuration ,
} ,
}
2023-12-15 01:15:50 +08:00
if r . Done {
resp . TotalDuration = time . Since ( checkpointStart )
resp . LoadDuration = checkpointLoaded . Sub ( checkpointStart )
2023-12-06 03:57:33 +08:00
}
ch <- resp
}
// Start prediction
predictReq := llm . PredictOpts {
2024-01-04 01:01:42 +08:00
Prompt : prompt ,
Format : req . Format ,
2024-02-01 11:18:25 +08:00
Images : images ,
2024-01-04 01:01:42 +08:00
Options : opts ,
2023-12-06 03:57:33 +08:00
}
if err := loaded . runner . Predict ( c . Request . Context ( ) , predictReq , fn ) ; err != nil {
ch <- gin . H { "error" : err . Error ( ) }
}
} ( )
if req . Stream != nil && ! * req . Stream {
2023-12-10 23:53:38 +08:00
// Accumulate responses into the final response
var final api . ChatResponse
2023-12-06 03:57:33 +08:00
var sb strings . Builder
for resp := range ch {
2023-12-10 23:53:38 +08:00
switch r := resp . ( type ) {
case api . ChatResponse :
2023-12-19 03:23:38 +08:00
sb . WriteString ( r . Message . Content )
2023-12-10 23:53:38 +08:00
final = r
case gin . H :
if errorMsg , ok := r [ "error" ] . ( string ) ; ok {
c . JSON ( http . StatusInternalServerError , gin . H { "error" : errorMsg } )
return
} else {
c . JSON ( http . StatusInternalServerError , gin . H { "error" : "unexpected error format in response" } )
return
}
default :
c . JSON ( http . StatusInternalServerError , gin . H { "error" : "unexpected error" } )
return
2023-12-06 03:57:33 +08:00
}
}
2023-12-10 23:53:38 +08:00
2023-12-19 03:23:38 +08:00
final . Message = api . Message { Role : "assistant" , Content : sb . String ( ) }
2023-12-10 23:53:38 +08:00
c . JSON ( http . StatusOK , final )
2023-12-06 03:57:33 +08:00
return
}
streamResponse ( c , ch )
}
2024-01-31 04:59:29 +08:00
// promptInfo pairs the template variables for one chat turn with the number of
// tokens the loaded runner's Encode produced for that turn's rendered template.
// trimmedPrompt and includeSystemPrompt use tokenLen to budget the context window.
type promptInfo struct {
// vars holds the template inputs for this turn (its Prompt text, Images, System message, First flag).
vars PromptVars
// tokenLen is the token count of the rendered turn text only; image tokens are budgeted separately by trimmedPrompt.
tokenLen int
}
// trimmedPrompt builds the prompt string and image list to send to the loaded runner,
// keeping as many of the most recent chat turns as fit in loaded.NumCtx while preserving
// the most recent system message (chat.LastSystem).
//
// Turns are considered newest first. The newest turn is always kept, even if it alone
// exceeds the context. Each image is budgeted at a fixed 768 tokens. An image that no longer
// fits is dropped and its " [img-N]" tag is removed from that turn's prompt text.
//
// Returns the rendered prompt, the images to send, and any templating/encoding error.
// An empty chat yields ("", nil, nil).
2024-02-01 11:18:25 +08:00
func trimmedPrompt ( ctx context . Context , chat * ChatHistory , model * Model ) ( string , [ ] llm . ImageData , error ) {
2024-01-31 04:59:29 +08:00
if len ( chat . Prompts ) == 0 {
2024-02-01 09:39:38 +08:00
return "" , nil , nil
2024-01-31 04:59:29 +08:00
}
// promptsToAdd is filled newest first: promptsToAdd[0] is the most recent turn.
var promptsToAdd [ ] promptInfo
// totalTokenLength counts both text tokens and the fixed 768-token image budget.
var totalTokenLength int
var systemPromptIncluded bool
2024-02-01 11:18:25 +08:00
var images [ ] llm . ImageData
2024-01-31 04:59:29 +08:00
// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
for i := len ( chat . Prompts ) - 1 ; i >= 0 ; i -- {
2024-02-02 03:21:17 +08:00
// prompt is a copy of the turn, so stripping image tags below does not modify chat.
prompt := chat . Prompts [ i ]
// The text is rendered and tokenized before image tags are stripped, so tokenLen
// may overestimate the final text length.
promptText , err := promptString ( model , prompt , i == len ( chat . Prompts ) - 1 )
2024-01-31 04:59:29 +08:00
if err != nil {
2024-02-01 09:39:38 +08:00
return "" , nil , err
2024-01-31 04:59:29 +08:00
}
encodedTokens , err := loaded . runner . Encode ( ctx , promptText )
if err != nil {
2024-02-01 09:39:38 +08:00
return "" , nil , err
2024-01-31 04:59:29 +08:00
}
// The newest turn (i == len-1) is exempt from this check, so it is always kept.
if totalTokenLength + len ( encodedTokens ) > loaded . NumCtx && i != len ( chat . Prompts ) - 1 {
break // reached max context length, stop adding more prompts
}
2024-02-02 03:21:17 +08:00
// Images are budgeted before this turn's text tokens are added to totalTokenLength.
for j := range prompt . Images {
if totalTokenLength + 768 > loaded . NumCtx {
// this decreases the token length but overestimating is fine
prompt . Prompt = strings . ReplaceAll ( prompt . Prompt , fmt . Sprintf ( " [img-%d]" , prompt . Images [ j ] . ID ) , "" )
continue
}
2024-02-01 09:39:38 +08:00
2024-02-02 03:21:17 +08:00
totalTokenLength += 768
// NOTE(review): images are collected here, but are never removed if includeSystemPrompt
// later drops the turn they belong to, so such images would still be returned; their
// 768-token budget also stays in totalTokenLength, since includeSystemPrompt only
// subtracts tokenLen. Confirm whether this is intended.
images = append ( images , prompt . Images [ j ] )
}
2024-02-02 01:50:48 +08:00
2024-01-31 04:59:29 +08:00
totalTokenLength += len ( encodedTokens )
2024-02-02 03:21:17 +08:00
systemPromptIncluded = systemPromptIncluded || prompt . System != ""
promptsToAdd = append ( promptsToAdd , promptInfo { vars : prompt , tokenLen : len ( encodedTokens ) } )
2024-01-31 04:59:29 +08:00
}
// ensure the system prompt is included, if not already
if chat . LastSystem != "" && ! systemPromptIncluded {
var err error
promptsToAdd , err = includeSystemPrompt ( ctx , chat . LastSystem , totalTokenLength , promptsToAdd )
if err != nil {
2024-02-01 09:39:38 +08:00
return "" , nil , err
2024-01-31 04:59:29 +08:00
}
}
// The last element is the oldest kept turn; mark it as the first in the rendered conversation.
// promptsToAdd is non-empty here because the newest turn is always appended above.
promptsToAdd [ len ( promptsToAdd ) - 1 ] . vars . First = true
// construct the final prompt string from the prompts which fit within the context window
// Iterating newest first while prepending yields the turns in chronological order.
var result string
for i , prompt := range promptsToAdd {
promptText , err := promptString ( model , prompt . vars , i == 0 )
if err != nil {
2024-02-01 09:39:38 +08:00
return "" , nil , err
2024-01-31 04:59:29 +08:00
}
result = promptText + result
}
2024-02-01 11:18:25 +08:00
2024-02-01 09:39:38 +08:00
return result , images , nil
2024-01-31 04:59:29 +08:00
}
// promptString renders vars with the model's template. The most recent turn
// goes through model.PreResponsePrompt; its errors are wrapped with a
// "pre-response template" prefix. Earlier turns use Prompt with
// model.Template, and their errors are returned unchanged. On error the
// returned string is always empty.
func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
	if !isMostRecent {
		text, err := Prompt(model.Template, vars)
		if err != nil {
			return "", err
		}
		return text, nil
	}

	text, err := model.PreResponsePrompt(vars)
	if err != nil {
		return "", fmt.Errorf("pre-response template: %w", err)
	}
	return text, nil
}
// includeSystemPrompt makes sure systemPrompt is part of the prompts sent to
// the model.
//
// promptsToAdd must be non-empty and ordered newest first: index 0 is the most
// recent turn, which is how trimmedPrompt builds it. totalTokenLength is the
// number of tokens those prompts already consume.
//
// Working from the oldest prompt toward the newest, it drops prompts until the
// system prompt's tokens fit within loaded.NumCtx. It then attaches the system
// prompt to the oldest prompt that is kept and returns the kept prefix.
//
// If the system prompt does not fit even next to the newest prompt alone, only
// the newest prompt is returned, with the system prompt attached.
//
// Note: promptsToAdd is modified in place when a kept prompt receives the
// system prompt.
func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
	systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
	if err != nil {
		return nil, err
	}

	for i := len(promptsToAdd) - 1; i >= 0; i-- {
		if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
			promptsToAdd[i].vars.System = systemPrompt
			return promptsToAdd[:i+1], nil
		}
		totalTokenLength -= promptsToAdd[i].tokenLen
	}

	// The system prompt did not fit anywhere. Keep the most recent turn, which
	// is element 0 because promptsToAdd is ordered newest first (the last
	// element is the oldest turn), and attach the system message to it.
	recent := promptsToAdd[0]
	recent.vars.System = systemPrompt
	return []promptInfo{recent}, nil
}