package ollamarunner

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"hash/maphash"
	"image"
	"log"
	"log/slog"
	"net"
	"net/http"
	"os"
	"reflect"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"golang.org/x/image/bmp"
	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn/pooling"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
	"github.com/ollama/ollama/runner/common"
	"github.com/ollama/ollama/sample"

	_ "github.com/ollama/ollama/model/models"
)

type Sequence struct {
	// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
	// multimodal embeddings
	ctxs []ml.Context

	// mmStore holds multimodal embeddings to manage memory and enable splitting across batches
	mmStore multimodalStore

	// batch index
	iBatch int

	// prompt inputs left to evaluate
	inputs []*input.Input

	// inputs that have been added to a batch but not yet submitted to Forward
	pendingInputs []*input.Input

	// tokens that have been generated but not returned yet (e.g. for stop sequences)
	pendingResponses []string

	// input cache being used by this sequence
	cache *InputCacheSlot

	// channel to send responses over
	responses chan string

	// channel to stop decoding (such as if the remote connection is closed)
	quit chan bool

	// number of tokens to predict
	numPredict int

	// sampler with transforms to run on generated logits
	sampler sample.Sampler

	// channel to send back the embedding if embedding only
	embedding chan []float32

	// stop sequences
	stop []string

	// number of inputs to keep at the beginning when shifting context window
	numKeep int32

	// true if an embedding is to be returned instead of text generation
	embeddingOnly bool

	doneReason llm.DoneReason

	// Metrics
	startProcessingTime time.Time
	startGenerationTime time.Time
	numPredicted        int
	numPromptInputs     int
}

type NewSequenceParams struct {
	numPredict int
	stop       []string
	numKeep    int32
	sampler    sample.Sampler
	embedding  bool
}

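// NewSequence converts a prompt and any attached images into a Sequence ready
// for scheduling. It tokenizes and encodes the inputs and, if the prompt is
// longer than the context window, truncates from the middle while preserving
// the first numKeep inputs and keeping unbreakable (SameBatch) runs intact.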
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
	s.ready.Wait()

	startTime := time.Now()

	inputs, ctxs, mmStore, err := s.inputs(prompt, images)
	if err != nil {
		return nil, fmt.Errorf("failed to process inputs: %w", err)
	} else if len(inputs) == 0 {
		return nil, errors.New("no input provided")
	}

	if params.numKeep < 0 {
		params.numKeep = int32(len(inputs))
	}

	// Ensure that at least 1 input can be discarded during shift
	params.numKeep = min(params.numKeep, s.cache.numCtx-1)

	if int32(len(inputs)) > s.cache.numCtx {
		discard := int32(len(inputs)) - s.cache.numCtx
		promptStart := params.numKeep + discard

		// If we need to truncate in the middle of an unbreakable batch, remove the entire batch
		sameBatch := 0
		for i, inp := range inputs {
			if sameBatch > 0 {
				sameBatch--

				if promptStart == int32(i) {
					promptStart++
				}
			} else if promptStart == int32(i) {
				break
			}

			if inp.SameBatch != 0 {
				if int32(i) < params.numKeep {
					return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
				}

				sameBatch = inp.SameBatch
			}
		}

		if promptStart >= int32(len(inputs)) {
			return nil, errors.New("entire prompt removed by truncation")
		}

		newInputs := inputs[:params.numKeep]
		newInputs = append(newInputs, inputs[promptStart:]...)

		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
		inputs = newInputs
	}

	// TODO(jessegross): Ingest cached history for grammar

	return &Sequence{
		ctxs:                ctxs,
		mmStore:             mmStore,
		inputs:              inputs,
		numPromptInputs:     len(inputs),
		startProcessingTime: startTime,
		numPredict:          params.numPredict,
		pendingResponses:    make([]string, 0),
		responses:           make(chan string, 100),
		quit:                make(chan bool, 1),
		embedding:           make(chan []float32, 1),
		sampler:             params.sampler,
		embeddingOnly:       params.embedding,
		stop:                params.stop,
		numKeep:             params.numKeep,
	}, nil
}

// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
	var inputs []*input.Input
	var ctxs []ml.Context
	var mmStore multimodalStore

	var parts []string
	var matches [][]string

	multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)

	if visionModel {
		re := regexp.MustCompile(`\[img-(\d+)\]`)
		parts = re.Split(prompt, -1)
		matches = re.FindAllStringSubmatch(prompt, -1)
		mmStore = newMultimodalStore()
	} else {
		parts = []string{prompt}
	}

	postTokenize := false
	for i, part := range parts {
		// text - tokenize
		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
		if err != nil {
			return nil, nil, nil, err
		}

		for _, t := range tokens {
			inputs = append(inputs, &input.Input{Token: t})
		}

		// image - decode and store
		if i < len(matches) {
			n, _ := strconv.Atoi(matches[i][1])

			imageIndex := -1
			for j := range images {
				if images[j].ID == n {
					imageIndex = j
					break
				}
			}

			if imageIndex < 0 {
				return nil, nil, nil, fmt.Errorf("invalid image index: %d", n)
			}

			ctx := s.model.Backend().NewContext()
			runtime.SetFinalizer(ctx, func(c ml.Context) { c.Close() })
			ctxs = append(ctxs, ctx)
			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
			if err != nil {
				return nil, nil, nil, err
			}

			s.multimodalHash.Reset()
			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
			imageHash := s.multimodalHash.Sum64()

			mmStore.addMultimodal(imageEmbeddings)

			inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
			postTokenize = true
		}
	}

	if visionModel && postTokenize {
		var err error
		inputs, err = multimodalProcessor.PostTokenize(inputs)
		if err != nil {
			return nil, nil, nil, err
		}
	}

	return inputs, ctxs, mmStore, nil
}

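// batchState carries one batch through the pipeline: forwardBatch builds the
// compute graph with placeholder input tokens, then computeBatch fills in the
// real token values and runs the graph. The channels below order the two
// stages so a new batch can be assembled while the previous one is computing.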
type batchState struct {
	// id provides a counter for trace logging batches
	id int

	// ctx holds the backend context used for this batch
	ctx ml.Context

	// modelOutput holds the outputs from this batch
	modelOutput ml.Tensor

	// batchInputs holds the input token pointers, which may start as
	// placeholders that are filled in before calling ctx.Compute
	batchInputs []*input.Input

	// batch contains the inputs for a model forward pass
	batch input.Batch

	// full set of seqs at the time this batch was initiated
	seqs []*Sequence

	// Signaled when this batch's inputs are ready and compute can proceed
	inputsReadyCh chan struct{}

	// Signaled when Compute is about to begin on this batch, and
	// seqs have been updated to prepare for the next batch
	computeStartedCh chan struct{}

	// Signaled when this batch's outputs are complete and the next batch can proceed
	outputsReadyCh chan struct{}
}

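// Server is the in-process runner: it owns the loaded model, the KV cache and
// the set of in-flight sequences, and serves the runner's HTTP endpoints.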
type Server struct {
	// modelPath is the location of the model to be loaded
	modelPath string

	// loadMu prevents more than one load attempt from occurring at a time
	loadMu sync.Mutex

	// lastLoad is the load request from the previous load attempt. Used to
	// detect if we can reuse an existing memory allocation.
	lastLoad llm.LoadRequest

	// is the server ready to process requests?
	// protects access to model and image
	ready sync.WaitGroup

	// loaded model
	model model.Model

	// status for external health reporting - loading, ready to serve, etc.
	status llm.ServerStatus

	// current progress on loading the model
	progress float32

	// number of simultaneous requests to handle
	parallel int

	// maximum number of elements in a batch (per sequence)
	// TODO (jmorganca): make this n_batch
	batchSize int

	// Used to signal a hard failure during async processing which will panic the runner
	hardErrCh chan error

	// Simple counter used only for trace logging batches
	batchID int

	// protects access to everything below this line
	// this is context state needed for decoding
	mu sync.Mutex

	// indicates that data is ready for processing
	cond *sync.Cond

	// the list of simultaneous sequences being evaluated
	seqs []*Sequence

	// seqs can have a maximum of parallel entries, which
	// is enforced by seqsSem
	seqsSem *semaphore.Weighted

	// KV cache
	cache *InputCache

	// next sequence for prompt processing to avoid starvation
	nextSeq int

	// multimodalHash generates hashes for comparing equality
	// of non-text data
	multimodalHash maphash.Hash
}

func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}
	return true
}

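// flushPending sends any buffered response text for a sequence, dropping a
// trailing partial UTF-8 character if present. It returns false if the
// sequence has been cancelled by the client.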
func flushPending(seq *Sequence) bool {
	joined := strings.Join(seq.pendingResponses, "")
	seq.pendingResponses = []string{}

	// Check if there are any partial UTF-8 characters remaining.
	// We already check and queue as we are generating but some may
	// still make it here:
	// - Sequence is ending, e.g. generation limit has been hit
	// - Invalid characters in the middle of a string
	// This is a stricter check to ensure we never output invalid Unicode.
	for !utf8.ValidString(joined) {
		joined = joined[:len(joined)-1]
	}

	if len(joined) == 0 {
		return true
	}

	select {
	case seq.responses <- joined:
		return true
	case <-seq.quit:
		return false
	}
}

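// removeSequence finishes a sequence: it flushes pending output, records the
// done reason, closes the response channels, releases the cache slot and
// frees the slot in s.seqs. Callers in this file invoke it while holding s.mu.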
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
	seq := s.seqs[seqIndex]

	flushPending(seq)
	seq.doneReason = reason
	close(seq.responses)
	close(seq.embedding)
	seq.cache.InUse = false
	s.seqs[seqIndex] = nil
	s.seqsSem.Release(1)
}

// track batch state between forwardBatch, computeBatch and predictForwardBatch

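// run is the main scheduling loop: it repeatedly assembles the next batch with
// forwardBatch and executes it with computeBatch. For token-generating models
// (no pooling) compute runs on its own goroutine so the next batch can be
// prepared while the current one is still executing; pooled embedding models
// are processed synchronously.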
func (s *Server) run(ctx context.Context) {
	s.ready.Wait()

	supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone

	var activeBatch batchState
	for {
		select {
		case <-ctx.Done():
			return
		case err := <-s.hardErrCh:
			panic(err)
		default:
			var err error
			activeBatch, err = s.forwardBatch(activeBatch)
			if err != nil {
				panic(err)
			}

			if supportsAsync {
				go s.computeBatch(activeBatch)
			} else {
				s.computeBatch(activeBatch)
			}
		}
	}
}

// forwardBatch assembles the next batch from the active sequences and builds
// its compute graph. The actual input token values are filled in later by
// computeBatch.
func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
	// If we have a pending batch still processing, wait until Compute has started
	// before setting up the next batch so the seqs' inputs are ready to receive their
	// token values and we get the correct input pointers for the batchInputs
	if pendingBatch.ctx != nil {
		logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
		<-pendingBatch.computeStartedCh
		logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
		nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the outputs from the pending batch to the next batch's inputs
	} else {
		logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID)
		// No pendingBatch, so the inputs will be ready in the seqs immediately
		nextBatch.inputsReadyCh = make(chan struct{}, 1)
		nextBatch.inputsReadyCh <- struct{}{}
	}

	s.mu.Lock()
	for s.allNil() {
		s.cond.Wait() // Wait until an item is added
	}
	defer s.mu.Unlock()

	nextBatch.ctx = s.model.Backend().NewContext()
	defer func() {
		if err != nil {
			nextBatch.ctx.Close()
			nextBatch.ctx = nil
		}
	}()
	nextBatch.id = s.batchID
	nextBatch.seqs = append([]*Sequence{}, s.seqs...)
	nextBatch.computeStartedCh = make(chan struct{}, 1)
	nextBatch.outputsReadyCh = make(chan struct{}, 1)

	// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
	var batchInputs []*input.Input
	var batchOutputs []int32
	var batch input.Batch

	resumeSeq := -1
	seqIdx := s.nextSeq - 1
	for range s.seqs {
		seqIdx = (seqIdx + 1) % len(s.seqs)
		seq := s.seqs[seqIdx]
		if seq == nil {
			continue
		}

		// if past the num predict limit
		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
			s.removeSequence(seqIdx, llm.DoneReasonLength)
			nextBatch.seqs[seqIdx] = nil
			continue
		}
		if seq.numPredict == -2 && int32(len(seq.cache.Inputs)) >= s.cache.numCtx {
			s.removeSequence(seqIdx, llm.DoneReasonLength)
			continue
		}

		if !s.cache.enabled {
			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
			seq.cache.Inputs = []*input.Input{}
		}

		batchSize := s.batchSize

		for i, inp := range seq.inputs {
			// If we are required to put following inputs into a single batch then extend the
			// batch size. Since we are only extending the size the minimum amount possible, this
			// will cause a break if we have existing inputs.
			minBatch := 1 + inp.SameBatch
			if minBatch > batchSize {
				batchSize = minBatch
			}

			// Stop if the required batch would put us over the total batch size (including tokens
			// added by other sequences). If we haven't been able to add anything yet then pick up
			// here again for the next batch to avoid starvation, though we can opportunistically
			// check if other sequences can still squeeze something in.
			if len(batchInputs)+minBatch > batchSize {
				if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
					resumeSeq = seqIdx
				}
				break
			}

			// If the sum of our working set (already processed tokens, tokens we added to this
			// batch, required following tokens) exceeds the context size, then trigger a shift
			// now so we don't have to do one later when we can't break the batch.
			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
				if len(seq.pendingInputs) != 0 {
					break
				}

				err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
				if err != nil {
					var reprocess *ErrReprocessInputs
					if errors.As(err, &reprocess) {
						// Prepend these inputs to the sequence's inputs queue for reprocessing
						seq.inputs = append(reprocess.Inputs, seq.inputs...)
						// Skip this sequence but continue processing the rest
						nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
						err = nil
						continue
					} else {
						return
					}
				}
			}

			batchInputs = append(batchInputs, seq.inputs[i])
			if inp.Multimodal != nil {
				var mm []input.Multimodal
				mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
				if err != nil {
					return
				}
				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
			}

			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
			batch.Sequences = append(batch.Sequences, seq.cache.Id)

			seq.iBatch = len(batchOutputs)
			if i+1 == len(seq.inputs) || seq.embeddingOnly {
				batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
			}
			logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
			seq.pendingInputs = append(seq.pendingInputs, inp)
		}

		seq.inputs = seq.inputs[len(seq.pendingInputs):]
	}

	if resumeSeq != -1 {
		s.nextSeq = resumeSeq
	} else {
		s.nextSeq = seqIdx + 1
	}

	if len(batchInputs) == 0 {
		logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID)
		nextBatch.ctx.Close()
		nextBatch.ctx = nil
		return
	}
	s.batchID++

	// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
	batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
	batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
	nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
	if err != nil {
		err = fmt.Errorf("failed to build graph: %w", err)
		return
	}
	nextBatch.batchInputs = batchInputs
	nextBatch.batch = batch

	return
}

// computeBatch runs the forward pass for a batch assembled by forwardBatch,
// possibly asynchronously on its own goroutine: it fills in the real input
// token values, starts Compute, samples the next token for each sequence and
// streams the decoded text back to waiting requests.
func (s *Server) computeBatch(activeBatch batchState) {
	if activeBatch.ctx == nil {
		// Nothing to compute
		return
	}
	defer activeBatch.ctx.Close()

	// Wait until inputs are ready
	logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
	<-activeBatch.inputsReadyCh
	logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id)

	// Once we complete, signal that the next batch's inputs are ready.
	// This will unblock the next computeBatch, or forwardBatch if new seqs come in
	defer func() {
		logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id)
		activeBatch.outputsReadyCh <- struct{}{}
	}()

	s.mu.Lock()

	// Gather the actual input token values now that they're ready
	batchInputs := make([]int32, len(activeBatch.batchInputs))
	for i := range batchInputs {
		batchInputs[i] = activeBatch.batchInputs[i].Token
	}

	// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
	// so that forwardBatch can build a batchInputs set which will eventually contain the actual
	// decoded tokens.
	nextBatchTokens := make([]*input.Input, len(s.seqs))
	iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
	for i, seq := range s.seqs {
		iBatches[i] = -1
		if seq == nil {
			continue
		}
		// Skip over any newly added or skipped sequences
		if activeBatch.seqs[i] == nil {
			continue
		}

		// Detect if the sequence we're processing has already been completed and replaced
		// with a new sequence
		if seq != activeBatch.seqs[i] {
			logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
			continue
		}

		// Pending inputs will actually be in the cache after we call Compute.
		// However, we have already resolved any placeholder tokens.
		//
		// It's possible for incoming sequences to look at the values that we've
		// added to the cache here and start relying on them before we've done
		// the computation. This is OK as long as we ensure that this batch's
		// computation happens before any future batch's and we never fail
		// (unless we take down the whole runner).
		if len(seq.pendingInputs) > 0 {
			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
			seq.pendingInputs = []*input.Input{}
		}

		// don't sample prompt processing
		if len(seq.inputs) != 0 {
			if !s.cache.enabled {
				s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
				s.mu.Unlock()
				return
			}
			continue
		}

		seq.numPredicted++
		nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
		seq.inputs = []*input.Input{nextToken}
		nextBatchTokens[i] = nextToken
		iBatches[i] = seq.iBatch
	}

	// At this point the seqs are ready for forwardBatch to move forward so unblock
	s.mu.Unlock()

	activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
	activeBatch.ctx.ComputeWithNotify(
		func() {
			logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
			activeBatch.computeStartedCh <- struct{}{}
		},
		activeBatch.modelOutput)

	outputs := activeBatch.modelOutput.Floats()

	logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)

	s.mu.Lock()
	defer s.mu.Unlock()

	logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id)
	for i, seq := range s.seqs {
		if seq == nil || nextBatchTokens[i] == nil {
			continue
		}

		if seq.numPredicted == 1 {
			seq.startGenerationTime = time.Now()
		}

		// if done processing the prompt, generate an embedding and return
		if seq.embeddingOnly {
			seq.embedding <- outputs
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// sample a token
		vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
		token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
		if err != nil {
			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
			return
		}

		nextBatchTokens[i].Token = token

		// if it's an end of sequence token, break
		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece
			logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
		if err != nil {
			s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
			return
		}

		seq.pendingResponses = append(seq.pendingResponses, piece)
		sequence := strings.Join(seq.pendingResponses, "")

		if ok, stop := common.FindStop(sequence, seq.stop); ok {
			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

			var tokenTruncated bool
			origLen := len(seq.pendingResponses)
			seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
			newLen := len(seq.pendingResponses)

			// Update the cache based on the tokens that will be returned:
			// - We have 1 token more than is currently in the cache because
			//   the last one generated wasn't submitted to Decode
			// - Remove any stop sequences that we stripped out
			// - If TruncateStop removed a portion of a token, drop that
			// - As defense-in-depth, if TruncateStop didn't find a stop token
			//   remove the extra one that we added to the cache len
			tokenLen := len(seq.cache.Inputs) + 1
			tokenLen -= origLen - newLen
			if tokenTruncated || origLen == newLen {
				tokenLen--
			}

			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		if common.ContainsStopSuffix(sequence, seq.stop) {
			continue
		}

		if common.IncompleteUnicode(sequence) {
			continue
		}

		if !flushPending(seq) {
			s.removeSequence(i, llm.DoneReasonConnectionClosed)
		}
	}
}

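// completion handles POST /completion: it builds a sampler from the request
// options, queues a new sequence, and streams CompletionResponse chunks back
// to the client until the sequence finishes or the connection is closed.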
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
	var req llm.CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	var grammar *sample.GrammarSampler
	var err error
	if req.Grammar != "" {
		grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
		if err != nil {
			http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
			return
		}
		defer grammar.Free()
	}

	sampler := sample.NewSampler(
		req.Options.Temperature,
		req.Options.TopK,
		req.Options.TopP,
		req.Options.MinP,
		req.Options.Seed,
		grammar,
	)

	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
		numPredict: req.Options.NumPredict,
		stop:       req.Options.Stop,
		numKeep:    int32(req.Options.NumKeep),
		sampler:    sampler,
		embedding:  false,
	})
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	for {
		select {
		case <-r.Context().Done():
			close(seq.quit)
			return
		case content, ok := <-seq.responses:
			if ok {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Content: content,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
					close(seq.quit)
					return
				}

				flusher.Flush()
			} else {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Done:               true,
					DoneReason:         seq.doneReason,
					PromptEvalCount:    seq.numPromptInputs,
					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
					EvalCount:          seq.numPredicted,
					EvalDuration:       time.Since(seq.startGenerationTime),
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
				}

				return
			}
		}
	}
}

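// embeddings handles POST /embedding: it runs the prompt through the model in
// embedding-only mode and returns the resulting embedding as JSON.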
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
	if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
		return
	}

	var req llm.EmbeddingRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
	if err != nil {
		http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting embedding request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
		Embedding: <-seq.embedding,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
		Status:   s.status,
		Progress: s.progress,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

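// reserveWorstCaseGraph builds (but does not compute) a maximal forward pass,
// a full batch plus, for multimodal models, the largest supported image, so
// the backend can reserve peak memory before any real request arrives.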
func (s *Server) reserveWorstCaseGraph() error {
	ctx := s.model.Backend().NewContext()
	defer ctx.Close()

	var err error
	inputs := make([]*input.Input, s.batchSize)
	for i := range inputs {
		inputs[i] = &input.Input{}
	}
	mmStore := newMultimodalStore()

	// Multimodal strategy:
	// - Encode a 2048x2048 image. This assumes that a single image of this
	//   size is sufficient to trigger the worst case. This is currently true
	//   because for existing models, only a single image fits in a batch.
	// - Add the embedding to a full batch of tokens - this is necessary because
	//   the model may be looking for non-image data, such as <image> tags.
	// - Run PostTokenize to execute any transformations between generated
	//   embeddings and what the forward pass expects.
	// - The result may now be larger than a batch (images may not fit in a
	//   single batch), so trim based on what will fit and must be grouped together.
	// - Fill out the rest of the space with text tokens.
	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
		mmCtx := s.model.Backend().NewContext()
		defer mmCtx.Close()

		img := image.NewGray(image.Rect(0, 0, 2048, 2048))
		var buf bytes.Buffer
		bmp.Encode(&buf, img)

		if inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes()); err == nil {
			mmStore.addMultimodal(inputs[0].Multimodal)

			inputs, err = multimodalProcessor.PostTokenize(inputs)
			if err != nil {
				return err
			}

			for i, inp := range inputs {
				minBatch := 1 + inp.SameBatch
				if minBatch > s.batchSize {
					inputs = inputs[i:min(i+minBatch, len(inputs))]
					break
				} else if i+minBatch > s.batchSize {
					inputs = inputs[:i]
					break
				}
			}

			if len(inputs) < s.batchSize {
				newInputs := make([]*input.Input, s.batchSize)
				copy(newInputs, inputs)
				for i := len(inputs); i < s.batchSize; i++ {
					newInputs[i] = &input.Input{}
				}
				inputs = newInputs
			}
		}
	}

	var batch input.Batch

	batchInputs := make([]int32, len(inputs))
	batch.Positions = make([]int32, len(inputs))
	batch.Sequences = make([]int, len(inputs))
	for i, inp := range inputs {
		batchInputs[i] = inp.Token
		if inp.Multimodal != nil {
			mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
			if err != nil {
				return err
			}
			batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
		}

		batch.Positions[i] = int32(i)
	}

	batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
	batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)

	cache := s.model.Config().Cache
	if cache != nil {
		err := cache.StartForward(ctx, batch, true)
		if err != nil {
			return err
		}
	}

	t, err := s.model.Forward(ctx, batch)
	if err != nil {
		return err
	}

	ctx.Forward(t).Reserve()

	return nil
}

// allocModel pre-allocates the maximum needed memory for a model
// based on the given parameters
func (s *Server) allocModel(
	mpath string,
	params ml.BackendParams,
	loraPath []string,
	parallel int,
	kvCacheType string,
	kvSize int,
	multiUserCache bool,
) (panicErr error) {
	// Convert memory allocation panics to errors
	defer func() {
		if r := recover(); r != nil {
			if err, ok := r.(error); ok {
				var noMem ml.ErrNoMem
				if errors.As(err, &noMem) {
					panicErr = noMem
				} else {
					panic(r)
				}
			} else {
				panic(r)
			}
		}
	}()

	var err error
	s.model, err = model.New(mpath, params)
	if err != nil {
		return err
	}

	// TODO(jessegross): LoRA loading
	if len(loraPath) > 0 {
		return errors.New("loras are not yet implemented")
	}

	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
	if err != nil {
		return err
	}

	if !s.cache.enabled && parallel > 1 {
		parallel = 1
		slog.Warn("model does not support caching, disabling parallel processing")
	}

	s.parallel = parallel
	s.seqs = make([]*Sequence, s.parallel)
	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

	return s.reserveWorstCaseGraph()
}

// closeModel frees all memory associated with a model
func (s *Server) closeModel() {
	s.cache.Close()
	s.cache = nil
	if s.model != nil {
		s.model.Backend().Close()
		s.model = nil
	}
}

// loadModel loads the weights for a model. The memory must already
// have been allocated with allocModel
func (s *Server) loadModel() {
	err := s.model.Backend().Load(context.TODO(),
		func(progress float32) {
			s.progress = progress
		})
	if err != nil {
		panic(fmt.Errorf("failed to load model: %v", err))
	}

	s.status = llm.ServerStatusReady
	s.ready.Done()
}

// load is the handler called by the Ollama server to process different
// load operations
func (s *Server) load(w http.ResponseWriter, r *http.Request) {
	s.loadMu.Lock()
	defer s.loadMu.Unlock()

	w.Header().Set("Content-Type", "application/json")

	if s.status != llm.ServerStatusLaunched {
		http.Error(w, "model already loaded", http.StatusInternalServerError)
		return
	}

	var req llm.LoadRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	slog.Info("load", "request", req)

	if req.Operation == llm.LoadOperationClose {
		s.closeModel()
		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.lastLoad.Operation = req.Operation
	loadModel := s.model == nil || !reflect.DeepEqual(req, s.lastLoad)

	s.lastLoad = req

	if loadModel {
		s.closeModel()

		params := ml.BackendParams{
			AllocMemory:    req.Operation != llm.LoadOperationFit,
			NumThreads:     req.NumThreads,
			GPULayers:      req.GPULayers,
			FlashAttention: req.FlashAttention,
		}

		s.batchSize = req.BatchSize

		err := s.allocModel(s.modelPath, params, req.LoraPath, req.Parallel, req.KvCacheType, req.KvSize, req.MultiUserCache)
		if err != nil {
			s.closeModel()

			var noMem ml.ErrNoMem
			if errors.As(err, &noMem) {
				resp := llm.LoadResponse{Success: false, Memory: noMem.BackendMemory}
				if err := json.NewEncoder(w).Encode(&resp); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
				}

				return
			}

			http.Error(w, fmt.Sprintf("failed to initialize model: %v", err), http.StatusInternalServerError)
			return
		}
	}

	mem := s.model.Backend().BackendMemory()

	switch req.Operation {
	case llm.LoadOperationFit:
		// LoadOperationFit can't be used for anything else, so just close it
		s.closeModel()

	// LoadOperationAlloc should stay open for future operations

	case llm.LoadOperationCommit:
		s.status = llm.ServerStatusLoadingModel
		go s.loadModel()
	}

	resp := llm.LoadResponse{Success: true, Memory: mem}
	if err := json.NewEncoder(w).Encode(&resp); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

// info is the handler called by the Ollama server to report information
// about the GPU devices in use by this runner
func (s *Server) info(w http.ResponseWriter, r *http.Request) {
	s.loadMu.Lock()
	defer s.loadMu.Unlock()

	w.Header().Set("Content-Type", "application/json")

	m := s.model

	if m == nil {
		startLoad := time.Now()

		// Dummy load to get the backend wired up
		f, err := os.CreateTemp("", "*.bin")
		if err != nil {
			http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
			return
		}
		defer f.Close()
		defer os.Remove(f.Name())

		if err := ggml.WriteGGUF(f, ggml.KV{
			"general.architecture": "llama",
			"tokenizer.ggml.model": "gpt2",
		}, nil); err != nil {
			http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
			return
		}

		m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
		if err != nil {
			http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
			return
		}
		slog.Debug("dummy model load took", "duration", time.Since(startLoad))
	}

	startDevices := time.Now()
	infos := m.Backend().BackendDevices()
	slog.Debug("gathering device infos took", "duration", time.Since(startDevices))
	if err := json.NewEncoder(w).Encode(&infos); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

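// Execute is the entry point for the runner subprocess: it parses the command
// line flags, starts the batch scheduling loop, and serves the runner's HTTP
// API on the requested port.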
func Execute(args []string) error {
	fs := flag.NewFlagSet("runner", flag.ExitOnError)
	mpath := fs.String("model", "", "Path to model binary file")
	port := fs.Int("port", 8080, "Port to expose the server on")
	_ = fs.Bool("verbose", false, "verbose output (default: disabled)")

	fs.Usage = func() {
		fmt.Fprintf(fs.Output(), "Runner usage\n")
		fs.PrintDefaults()
	}
	if err := fs.Parse(args); err != nil {
		return err
	}
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
	slog.Info("starting ollama engine")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	server := &Server{
		modelPath: *mpath,
		status:    llm.ServerStatusLaunched,
		hardErrCh: make(chan error, 1),
	}

	server.cond = sync.NewCond(&server.mu)
	server.ready.Add(1)

	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return err
	}
	defer listener.Close()

	mux := http.NewServeMux()
	// TODO: support embeddings
	mux.HandleFunc("GET /info", server.info)
	mux.HandleFunc("POST /load", server.load)
	mux.HandleFunc("POST /embedding", server.embeddings)
	mux.HandleFunc("POST /completion", server.completion)
	mux.HandleFunc("GET /health", server.health)

	httpServer := http.Server{
		Handler: mux,
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
		return err
	}

	return nil
}