mirror of https://github.com/ollama/ollama.git
				
				
				
			return error in generate response
This commit is contained in:
		
							parent
							
								
									2d49197b3b
								
							
						
					
					
						commit
						edba935d67
					
				|  | @ -5,6 +5,7 @@ import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | @ -25,6 +26,18 @@ func NewClient(hosts ...string) *Client { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func StatusError(status int, message ...string) error { | ||||||
|  | 	if status < 400 { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(message) > 0 && len(message[0]) > 0 { | ||||||
|  | 		return fmt.Errorf("%d %s: %s", status, http.StatusText(status), message[0]) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return fmt.Errorf("%d %s", status, http.StatusText(status)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| type options struct { | type options struct { | ||||||
| 	requestBody  io.Reader | 	requestBody  io.Reader | ||||||
| 	responseFunc func(bts []byte) error | 	responseFunc func(bts []byte) error | ||||||
|  | @ -70,7 +83,20 @@ func (c *Client) stream(ctx context.Context, method, path string, fns ...func(*o | ||||||
| 	if opts.responseFunc != nil { | 	if opts.responseFunc != nil { | ||||||
| 		scanner := bufio.NewScanner(response.Body) | 		scanner := bufio.NewScanner(response.Body) | ||||||
| 		for scanner.Scan() { | 		for scanner.Scan() { | ||||||
| 			if err := opts.responseFunc(scanner.Bytes()); err != nil { | 			var errorResponse struct { | ||||||
|  | 				Error string `json:"error"` | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			bts := scanner.Bytes() | ||||||
|  | 			if err := json.Unmarshal(bts, &errorResponse); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if err := StatusError(response.StatusCode, errorResponse.Error); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if err := opts.responseFunc(bts); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | @ -15,6 +15,7 @@ func (e Error) Error() string { | ||||||
| 	if e.Message == "" { | 	if e.Message == "" { | ||||||
| 		return fmt.Sprintf("%d %v", e.Code, strings.ToLower(http.StatusText(int(e.Code)))) | 		return fmt.Sprintf("%d %v", e.Code, strings.ToLower(http.StatusText(int(e.Code)))) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	return e.Message | 	return e.Message | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -100,14 +100,19 @@ func generate(model, prompt string) error { | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| 
 | 
 | ||||||
| 		client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error { | 		request := api.GenerateRequest{Model: model, Prompt: prompt} | ||||||
|  | 		fn := func(resp api.GenerateResponse) error { | ||||||
| 			if !spinner.IsFinished() { | 			if !spinner.IsFinished() { | ||||||
| 				spinner.Finish() | 				spinner.Finish() | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			fmt.Print(resp.Response) | 			fmt.Print(resp.Response) | ||||||
| 			return nil | 			return nil | ||||||
| 		}) | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := client.Generate(context.Background(), &request, fn); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 		fmt.Println() | 		fmt.Println() | ||||||
| 		fmt.Println() | 		fmt.Println() | ||||||
|  |  | ||||||
|  | @ -4,7 +4,6 @@ import ( | ||||||
| 	"embed" | 	"embed" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"io" | 	"io" | ||||||
| 	"log" | 	"log" | ||||||
| 	"math" | 	"math" | ||||||
|  | @ -46,7 +45,7 @@ func generate(c *gin.Context) { | ||||||
| 		req.PredictOptions = &api.DefaultPredictOptions | 		req.PredictOptions = &api.DefaultPredictOptions | ||||||
| 	} | 	} | ||||||
| 	if err := c.ShouldBindJSON(&req); err != nil { | 	if err := c.ShouldBindJSON(&req); err != nil { | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) | 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -66,7 +65,7 @@ func generate(c *gin.Context) { | ||||||
| 
 | 
 | ||||||
| 	model, err := llama.New(req.Model, modelOpts) | 	model, err := llama.New(req.Model, modelOpts) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		fmt.Println("Loading the model failed:", err.Error()) | 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	defer model.Free() | 	defer model.Free() | ||||||
|  | @ -80,7 +79,7 @@ func generate(c *gin.Context) { | ||||||
| 	if template := templates.Lookup(match); template != nil { | 	if template := templates.Lookup(match); template != nil { | ||||||
| 		var sb strings.Builder | 		var sb strings.Builder | ||||||
| 		if err := template.Execute(&sb, req); err != nil { | 		if err := template.Execute(&sb, req); err != nil { | ||||||
| 			fmt.Println("Prompt template failed:", err.Error()) | 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue