mirror of https://github.com/ollama/ollama.git
				
				
				
			
		
			
				
	
	
		
			387 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			387 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
| // Package api implements the client-side API for code wishing to interact
 | |
| // with the ollama service. The methods of the [Client] type correspond to
 | |
| // the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
 | |
| //
 | |
| // The ollama command-line client itself uses this package to interact with
 | |
| // the backend service.
 | |
| package api
 | |
| 
 | |
| import (
 | |
| 	"bufio"
 | |
| 	"bytes"
 | |
| 	"context"
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"net/url"
 | |
| 	"os"
 | |
| 	"runtime"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/ollama/ollama/format"
 | |
| 	"github.com/ollama/ollama/version"
 | |
| )
 | |
| 
 | |
| // Client encapsulates client state for interacting with the ollama
 | |
| // service. Use [ClientFromEnvironment] to create new Clients.
 | |
| type Client struct {
 | |
| 	base *url.URL
 | |
| 	http *http.Client
 | |
| }
 | |
| 
 | |
| func checkError(resp *http.Response, body []byte) error {
 | |
| 	if resp.StatusCode < http.StatusBadRequest {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	apiError := StatusError{StatusCode: resp.StatusCode}
 | |
| 
 | |
| 	err := json.Unmarshal(body, &apiError)
 | |
| 	if err != nil {
 | |
| 		// Use the full body as the message if we fail to decode a response.
 | |
| 		apiError.ErrorMessage = string(body)
 | |
| 	}
 | |
| 
 | |
| 	return apiError
 | |
| }
 | |
| 
 | |
| // ClientFromEnvironment creates a new [Client] using configuration from the
 | |
| // environment variable OLLAMA_HOST, which points to the network host and
 | |
| // port on which the ollama service is listenting. The format of this variable
 | |
| // is:
 | |
| //
 | |
| //	<scheme>://<host>:<port>
 | |
| //
 | |
| // If the variable is not specified, a default ollama host and port will be
 | |
| // used.
 | |
| func ClientFromEnvironment() (*Client, error) {
 | |
| 	ollamaHost, err := GetOllamaHost()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &Client{
 | |
| 		base: &url.URL{
 | |
| 			Scheme: ollamaHost.Scheme,
 | |
| 			Host:   net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
 | |
| 		},
 | |
| 		http: http.DefaultClient,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
// OllamaHost is the parsed form of the OLLAMA_HOST environment variable as
// produced by GetOllamaHost.
type OllamaHost struct {
	// Scheme is the URL scheme, e.g. "http".
	// Host is a host name or IP literal without brackets; it may be empty
	// when only ":port" was given.
	// Port is the port as a decimal string.
	Scheme, Host, Port string
}
 | |
| 
 | |
| func GetOllamaHost() (OllamaHost, error) {
 | |
| 	defaultPort := "11434"
 | |
| 
 | |
| 	hostVar := os.Getenv("OLLAMA_HOST")
 | |
| 	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
 | |
| 
 | |
| 	scheme, hostport, ok := strings.Cut(hostVar, "://")
 | |
| 	switch {
 | |
| 	case !ok:
 | |
| 		scheme, hostport = "http", hostVar
 | |
| 	case scheme == "http":
 | |
| 		defaultPort = "80"
 | |
| 	case scheme == "https":
 | |
| 		defaultPort = "443"
 | |
| 	}
 | |
| 
 | |
| 	// trim trailing slashes
 | |
| 	hostport = strings.TrimRight(hostport, "/")
 | |
| 
 | |
| 	host, port, err := net.SplitHostPort(hostport)
 | |
| 	if err != nil {
 | |
| 		host, port = "127.0.0.1", defaultPort
 | |
| 		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
 | |
| 			host = ip.String()
 | |
| 		} else if hostport != "" {
 | |
| 			host = hostport
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
 | |
| 		return OllamaHost{}, ErrInvalidHostPort
 | |
| 	}
 | |
| 
 | |
| 	return OllamaHost{
 | |
| 		Scheme: scheme,
 | |
| 		Host:   host,
 | |
| 		Port:   port,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func NewClient(base *url.URL, http *http.Client) *Client {
 | |
| 	return &Client{
 | |
| 		base: base,
 | |
| 		http: http,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
 | |
| 	var reqBody io.Reader
 | |
| 	var data []byte
 | |
| 	var err error
 | |
| 
 | |
| 	switch reqData := reqData.(type) {
 | |
| 	case io.Reader:
 | |
| 		// reqData is already an io.Reader
 | |
| 		reqBody = reqData
 | |
| 	case nil:
 | |
| 		// noop
 | |
| 	default:
 | |
| 		data, err = json.Marshal(reqData)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		reqBody = bytes.NewReader(data)
 | |
| 	}
 | |
| 
 | |
| 	requestURL := c.base.JoinPath(path)
 | |
| 	request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	request.Header.Set("Content-Type", "application/json")
 | |
| 	request.Header.Set("Accept", "application/json")
 | |
| 	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
 | |
| 
 | |
| 	respObj, err := c.http.Do(request)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	defer respObj.Body.Close()
 | |
| 
 | |
| 	respBody, err := io.ReadAll(respObj.Body)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if err := checkError(respObj, respBody); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if len(respBody) > 0 && respData != nil {
 | |
| 		if err := json.Unmarshal(respBody, respData); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| const maxBufferSize = 512 * format.KiloByte
 | |
| 
 | |
| func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 | |
| 	var buf *bytes.Buffer
 | |
| 	if data != nil {
 | |
| 		bts, err := json.Marshal(data)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		buf = bytes.NewBuffer(bts)
 | |
| 	}
 | |
| 
 | |
| 	requestURL := c.base.JoinPath(path)
 | |
| 	request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	request.Header.Set("Content-Type", "application/json")
 | |
| 	request.Header.Set("Accept", "application/x-ndjson")
 | |
| 	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
 | |
| 
 | |
| 	response, err := c.http.Do(request)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	defer response.Body.Close()
 | |
| 
 | |
| 	scanner := bufio.NewScanner(response.Body)
 | |
| 	// increase the buffer size to avoid running out of space
 | |
| 	scanBuf := make([]byte, 0, maxBufferSize)
 | |
| 	scanner.Buffer(scanBuf, maxBufferSize)
 | |
| 	for scanner.Scan() {
 | |
| 		var errorResponse struct {
 | |
| 			Error string `json:"error,omitempty"`
 | |
| 		}
 | |
| 
 | |
| 		bts := scanner.Bytes()
 | |
| 		if err := json.Unmarshal(bts, &errorResponse); err != nil {
 | |
| 			return fmt.Errorf("unmarshal: %w", err)
 | |
| 		}
 | |
| 
 | |
| 		if errorResponse.Error != "" {
 | |
| 			return fmt.Errorf(errorResponse.Error)
 | |
| 		}
 | |
| 
 | |
| 		if response.StatusCode >= http.StatusBadRequest {
 | |
| 			return StatusError{
 | |
| 				StatusCode:   response.StatusCode,
 | |
| 				Status:       response.Status,
 | |
| 				ErrorMessage: errorResponse.Error,
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if err := fn(bts); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // GenerateResponseFunc is a function that [Client.Generate] invokes every time
 | |
| // a response is received from the service. If this function returns an error,
 | |
| // [Client.Generate] will stop generating and return this error.
 | |
| type GenerateResponseFunc func(GenerateResponse) error
 | |
| 
 | |
| // Generate generates a response for a given prompt. The req parameter should
 | |
| // be populated with prompt details. fn is called for each response (there may
 | |
| // be multiple responses, e.g. in case streaming is enabled).
 | |
| func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
 | |
| 	return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
 | |
| 		var resp GenerateResponse
 | |
| 		if err := json.Unmarshal(bts, &resp); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		return fn(resp)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| // ChatResponseFunc is a function that [Client.Chat] invokes every time
 | |
| // a response is received from the service. If this function returns an error,
 | |
| // [Client.Chat] will stop generating and return this error.
 | |
| type ChatResponseFunc func(ChatResponse) error
 | |
| 
 | |
| // Chat generates the next message in a chat. [ChatRequest] may contain a
 | |
| // sequence of messages which can be used to maintain chat history with a model.
 | |
| // fn is called for each response (there may be multiple responses, e.g. if case
 | |
| // streaming is enabled).
 | |
| func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
 | |
| 	return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
 | |
| 		var resp ChatResponse
 | |
| 		if err := json.Unmarshal(bts, &resp); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		return fn(resp)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| // PullProgressFunc is a function that [Client.Pull] invokes every time there
 | |
| // is progress with a "pull" request sent to the service. If this function
 | |
| // returns an error, [Client.Pull] will stop the process and return this error.
 | |
| type PullProgressFunc func(ProgressResponse) error
 | |
| 
 | |
| // Pull downloads a model from the ollama library. fn is called each time
 | |
| // progress is made on the request and can be used to display a progress bar,
 | |
| // etc.
 | |
| func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
 | |
| 	return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
 | |
| 		var resp ProgressResponse
 | |
| 		if err := json.Unmarshal(bts, &resp); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		return fn(resp)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| type PushProgressFunc func(ProgressResponse) error
 | |
| 
 | |
| func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
 | |
| 	return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
 | |
| 		var resp ProgressResponse
 | |
| 		if err := json.Unmarshal(bts, &resp); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		return fn(resp)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| type CreateProgressFunc func(ProgressResponse) error
 | |
| 
 | |
| func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
 | |
| 	return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
 | |
| 		var resp ProgressResponse
 | |
| 		if err := json.Unmarshal(bts, &resp); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		return fn(resp)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 | |
| 	var lr ListResponse
 | |
| 	if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return &lr, nil
 | |
| }
 | |
| 
 | |
| func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
 | |
| 	if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
 | |
| 	if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
 | |
| 	var resp ShowResponse
 | |
| 	if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return &resp, nil
 | |
| }
 | |
| 
 | |
| func (c *Client) Heartbeat(ctx context.Context) error {
 | |
| 	if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 | |
| 	var resp EmbeddingResponse
 | |
| 	if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return &resp, nil
 | |
| }
 | |
| 
 | |
| func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
 | |
| 	return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
 | |
| }
 | |
| 
 | |
| func (c *Client) Version(ctx context.Context) (string, error) {
 | |
| 	var version struct {
 | |
| 		Version string `json:"version"`
 | |
| 	}
 | |
| 
 | |
| 	if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	return version.Version, nil
 | |
| }
 |