mirror of https://github.com/ollama/ollama.git
				
				
				
			return error in generate response
This commit is contained in:
		
							parent
							
								
									2d49197b3b
								
							
						
					
					
						commit
						edba935d67
					
				|  | @ -5,6 +5,7 @@ import ( | |||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
|  | @ -25,6 +26,18 @@ func NewClient(hosts ...string) *Client { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func StatusError(status int, message ...string) error { | ||||
| 	if status < 400 { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	if len(message) > 0 && len(message[0]) > 0 { | ||||
| 		return fmt.Errorf("%d %s: %s", status, http.StatusText(status), message[0]) | ||||
| 	} | ||||
| 
 | ||||
| 	return fmt.Errorf("%d %s", status, http.StatusText(status)) | ||||
| } | ||||
| 
 | ||||
| type options struct { | ||||
| 	requestBody  io.Reader | ||||
| 	responseFunc func(bts []byte) error | ||||
|  | @ -70,7 +83,20 @@ func (c *Client) stream(ctx context.Context, method, path string, fns ...func(*o | |||
| 	if opts.responseFunc != nil { | ||||
| 		scanner := bufio.NewScanner(response.Body) | ||||
| 		for scanner.Scan() { | ||||
| 			if err := opts.responseFunc(scanner.Bytes()); err != nil { | ||||
| 			var errorResponse struct { | ||||
| 				Error string `json:"error"` | ||||
| 			} | ||||
| 
 | ||||
| 			bts := scanner.Bytes() | ||||
| 			if err := json.Unmarshal(bts, &errorResponse); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			if err := StatusError(response.StatusCode, errorResponse.Error); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			if err := opts.responseFunc(bts); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
|  |  | |||
|  | @ -15,6 +15,7 @@ func (e Error) Error() string { | |||
| 	if e.Message == "" { | ||||
| 		return fmt.Sprintf("%d %v", e.Code, strings.ToLower(http.StatusText(int(e.Code)))) | ||||
| 	} | ||||
| 
 | ||||
| 	return e.Message | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -100,14 +100,19 @@ func generate(model, prompt string) error { | |||
| 			} | ||||
| 		}() | ||||
| 
 | ||||
| 		client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error { | ||||
| 		request := api.GenerateRequest{Model: model, Prompt: prompt} | ||||
| 		fn := func(resp api.GenerateResponse) error { | ||||
| 			if !spinner.IsFinished() { | ||||
| 				spinner.Finish() | ||||
| 			} | ||||
| 
 | ||||
| 			fmt.Print(resp.Response) | ||||
| 			return nil | ||||
| 		}) | ||||
| 		} | ||||
| 
 | ||||
| 		if err := client.Generate(context.Background(), &request, fn); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		fmt.Println() | ||||
| 		fmt.Println() | ||||
|  |  | |||
|  | @ -4,7 +4,6 @@ import ( | |||
| 	"embed" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"math" | ||||
|  | @ -46,7 +45,7 @@ func generate(c *gin.Context) { | |||
| 		req.PredictOptions = &api.DefaultPredictOptions | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&req); err != nil { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  | @ -66,7 +65,7 @@ func generate(c *gin.Context) { | |||
| 
 | ||||
| 	model, err := llama.New(req.Model, modelOpts) | ||||
| 	if err != nil { | ||||
| 		fmt.Println("Loading the model failed:", err.Error()) | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 	defer model.Free() | ||||
|  | @ -80,7 +79,7 @@ func generate(c *gin.Context) { | |||
| 	if template := templates.Lookup(match); template != nil { | ||||
| 		var sb strings.Builder | ||||
| 		if err := template.Execute(&sb, req); err != nil { | ||||
| 			fmt.Println("Prompt template failed:", err.Error()) | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue