| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | package openai | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"bytes" | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 	"encoding/base64" | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 	"encoding/json" | 
					
						
							|  |  |  | 	"io" | 
					
						
							|  |  |  | 	"net/http" | 
					
						
							|  |  |  | 	"net/http/httptest" | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 	"reflect" | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 	"strings" | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 	"testing" | 
					
						
							|  |  |  | 	"time" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/gin-gonic/gin" | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/ollama/ollama/api" | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | const ( | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 	prefix = `data:image/jpeg;base64,` | 
					
						
							|  |  |  | 	image  = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` | 
					
						
							| 
									
										
										
										
											2024-08-02 05:52:15 +08:00
										 |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | var False = false | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { | 
					
						
							|  |  |  | 	return func(c *gin.Context) { | 
					
						
							|  |  |  | 		bodyBytes, _ := io.ReadAll(c.Request.Body) | 
					
						
							|  |  |  | 		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) | 
					
						
							|  |  |  | 		err := json.Unmarshal(bodyBytes, capturedRequest) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request") | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		c.Next() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestChatMiddleware(t *testing.T) { | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 	type testCase struct { | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		name string | 
					
						
							|  |  |  | 		body string | 
					
						
							|  |  |  | 		req  api.ChatRequest | 
					
						
							|  |  |  | 		err  ErrorResponse | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 	var capturedRequest *api.ChatRequest | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-10 04:48:31 +08:00
										 |  |  | 	testCases := []testCase{ | 
					
						
							|  |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			name: "chat handler", | 
					
						
							|  |  |  | 			body: `{ | 
					
						
							|  |  |  | 				"model": "test-model", | 
					
						
							|  |  |  | 				"messages": [ | 
					
						
							|  |  |  | 					{"role": "user", "content": "Hello"} | 
					
						
							|  |  |  | 				] | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 			req: api.ChatRequest{ | 
					
						
							|  |  |  | 				Model: "test-model", | 
					
						
							|  |  |  | 				Messages: []api.Message{ | 
					
						
							|  |  |  | 					{ | 
					
						
							|  |  |  | 						Role:    "user", | 
					
						
							|  |  |  | 						Content: "Hello", | 
					
						
							|  |  |  | 					}, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 				Options: map[string]any{ | 
					
						
							|  |  |  | 					"temperature": 1.0, | 
					
						
							|  |  |  | 					"top_p":       1.0, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 				Stream: &False, | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 			}, | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			name: "chat handler with image content", | 
					
						
							|  |  |  | 			body: `{ | 
					
						
							|  |  |  | 				"model": "test-model", | 
					
						
							|  |  |  | 				"messages": [ | 
					
						
							|  |  |  | 					{ | 
					
						
							|  |  |  | 						"role": "user", | 
					
						
							|  |  |  | 						"content": [ | 
					
						
							|  |  |  | 							{ | 
					
						
							|  |  |  | 								"type": "text", | 
					
						
							|  |  |  | 								"text": "Hello" | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 							}, | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 							{ | 
					
						
							|  |  |  | 								"type": "image_url", | 
					
						
							|  |  |  | 								"image_url": { | 
					
						
							|  |  |  | 									"url": "` + prefix + image + `" | 
					
						
							|  |  |  | 								} | 
					
						
							|  |  |  | 							} | 
					
						
							|  |  |  | 						] | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 				] | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 			req: api.ChatRequest{ | 
					
						
							|  |  |  | 				Model: "test-model", | 
					
						
							|  |  |  | 				Messages: []api.Message{ | 
					
						
							|  |  |  | 					{ | 
					
						
							|  |  |  | 						Role:    "user", | 
					
						
							|  |  |  | 						Content: "Hello", | 
					
						
							|  |  |  | 					}, | 
					
						
							|  |  |  | 					{ | 
					
						
							|  |  |  | 						Role: "user", | 
					
						
							|  |  |  | 						Images: []api.ImageData{ | 
					
						
							|  |  |  | 							func() []byte { | 
					
						
							|  |  |  | 								img, _ := base64.StdEncoding.DecodeString(image) | 
					
						
							|  |  |  | 								return img | 
					
						
							|  |  |  | 							}(), | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 						}, | 
					
						
							|  |  |  | 					}, | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 				}, | 
					
						
							|  |  |  | 				Options: map[string]any{ | 
					
						
							|  |  |  | 					"temperature": 1.0, | 
					
						
							|  |  |  | 					"top_p":       1.0, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 				Stream: &False, | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			name: "chat handler with tools", | 
					
						
							|  |  |  | 			body: `{ | 
					
						
							|  |  |  | 				"model": "test-model", | 
					
						
							|  |  |  | 				"messages": [ | 
					
						
							|  |  |  | 					{"role": "user", "content": "What's the weather like in Paris Today?"}, | 
					
						
							|  |  |  | 					{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} | 
					
						
							|  |  |  | 				] | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 			req: api.ChatRequest{ | 
					
						
							|  |  |  | 				Model: "test-model", | 
					
						
							|  |  |  | 				Messages: []api.Message{ | 
					
						
							|  |  |  | 					{ | 
					
						
							|  |  |  | 						Role:    "user", | 
					
						
							|  |  |  | 						Content: "What's the weather like in Paris Today?", | 
					
						
							|  |  |  | 					}, | 
					
						
							|  |  |  | 					{ | 
					
						
							|  |  |  | 						Role: "assistant", | 
					
						
							|  |  |  | 						ToolCalls: []api.ToolCall{ | 
					
						
							|  |  |  | 							{ | 
					
						
							|  |  |  | 								Function: api.ToolCallFunction{ | 
					
						
							|  |  |  | 									Name: "get_current_weather", | 
					
						
							|  |  |  | 									Arguments: map[string]interface{}{ | 
					
						
							|  |  |  | 										"location": "Paris, France", | 
					
						
							|  |  |  | 										"format":   "celsius", | 
					
						
							|  |  |  | 									}, | 
					
						
							|  |  |  | 								}, | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 							}, | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 						}, | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 					}, | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 				}, | 
					
						
							|  |  |  | 				Options: map[string]any{ | 
					
						
							|  |  |  | 					"temperature": 1.0, | 
					
						
							|  |  |  | 					"top_p":       1.0, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 				Stream: &False, | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		{ | 
					
						
							|  |  |  | 			name: "chat handler error forwarding", | 
					
						
							|  |  |  | 			body: `{ | 
					
						
							|  |  |  | 				"model": "test-model", | 
					
						
							|  |  |  | 				"messages": [ | 
					
						
							|  |  |  | 					{"role": "user", "content": 2} | 
					
						
							|  |  |  | 				] | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 			err: ErrorResponse{ | 
					
						
							|  |  |  | 				Error: Error{ | 
					
						
							|  |  |  | 					Message: "invalid message content type: float64", | 
					
						
							|  |  |  | 					Type:    "invalid_request_error", | 
					
						
							|  |  |  | 				}, | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	endpoint := func(c *gin.Context) { | 
					
						
							|  |  |  | 		c.Status(http.StatusOK) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	gin.SetMode(gin.TestMode) | 
					
						
							|  |  |  | 	router := gin.New() | 
					
						
							|  |  |  | 	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest)) | 
					
						
							|  |  |  | 	router.Handle(http.MethodPost, "/api/chat", endpoint) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, tc := range testCases { | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		t.Run(tc.name, func(t *testing.T) { | 
					
						
							|  |  |  | 			req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body)) | 
					
						
							|  |  |  | 			req.Header.Set("Content-Type", "application/json") | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 			resp := httptest.NewRecorder() | 
					
						
							|  |  |  | 			router.ServeHTTP(resp, req) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			var errResp ErrorResponse | 
					
						
							|  |  |  | 			if resp.Code != http.StatusOK { | 
					
						
							|  |  |  | 				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { | 
					
						
							|  |  |  | 					t.Fatal(err) | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { | 
					
						
							|  |  |  | 				t.Fatal("requests did not match") | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			if !reflect.DeepEqual(tc.err, errResp) { | 
					
						
							|  |  |  | 				t.Fatal("errors did not match") | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 			capturedRequest = nil | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestCompletionsMiddleware(t *testing.T) { | 
					
						
							|  |  |  | 	type testCase struct { | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		name string | 
					
						
							|  |  |  | 		body string | 
					
						
							|  |  |  | 		req  api.GenerateRequest | 
					
						
							|  |  |  | 		err  ErrorResponse | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var capturedRequest *api.GenerateRequest | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	testCases := []testCase{ | 
					
						
							|  |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			name: "completions handler", | 
					
						
							|  |  |  | 			body: `{ | 
					
						
							|  |  |  | 				"model": "test-model", | 
					
						
							|  |  |  | 				"prompt": "Hello", | 
					
						
							|  |  |  | 				"temperature": 0.8, | 
					
						
							|  |  |  | 				"stop": ["\n", "stop"], | 
					
						
							|  |  |  | 				"suffix": "suffix" | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 			req: api.GenerateRequest{ | 
					
						
							|  |  |  | 				Model:  "test-model", | 
					
						
							|  |  |  | 				Prompt: "Hello", | 
					
						
							|  |  |  | 				Options: map[string]any{ | 
					
						
							|  |  |  | 					"frequency_penalty": 0.0, | 
					
						
							|  |  |  | 					"presence_penalty":  0.0, | 
					
						
							|  |  |  | 					"temperature":       1.6, | 
					
						
							|  |  |  | 					"top_p":             1.0, | 
					
						
							|  |  |  | 					"stop":              []any{"\n", "stop"}, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 				Suffix: "suffix", | 
					
						
							|  |  |  | 				Stream: &False, | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			name: "completions handler error forwarding", | 
					
						
							|  |  |  | 			body: `{ | 
					
						
							|  |  |  | 				"model": "test-model", | 
					
						
							|  |  |  | 				"prompt": "Hello", | 
					
						
							|  |  |  | 				"temperature": null, | 
					
						
							|  |  |  | 				"stop": [1, 2], | 
					
						
							|  |  |  | 				"suffix": "suffix" | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 			err: ErrorResponse{ | 
					
						
							|  |  |  | 				Error: Error{ | 
					
						
							|  |  |  | 					Message: "invalid type for 'stop' field: float64", | 
					
						
							|  |  |  | 					Type:    "invalid_request_error", | 
					
						
							|  |  |  | 				}, | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 	endpoint := func(c *gin.Context) { | 
					
						
							|  |  |  | 		c.Status(http.StatusOK) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 	gin.SetMode(gin.TestMode) | 
					
						
							|  |  |  | 	router := gin.New() | 
					
						
							|  |  |  | 	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest)) | 
					
						
							|  |  |  | 	router.Handle(http.MethodPost, "/api/generate", endpoint) | 
					
						
							| 
									
										
										
										
											2024-07-14 13:07:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 	for _, tc := range testCases { | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		t.Run(tc.name, func(t *testing.T) { | 
					
						
							|  |  |  | 			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body)) | 
					
						
							|  |  |  | 			req.Header.Set("Content-Type", "application/json") | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 			resp := httptest.NewRecorder() | 
					
						
							|  |  |  | 			router.ServeHTTP(resp, req) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			var errResp ErrorResponse | 
					
						
							|  |  |  | 			if resp.Code != http.StatusOK { | 
					
						
							|  |  |  | 				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { | 
					
						
							|  |  |  | 					t.Fatal(err) | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { | 
					
						
							|  |  |  | 				t.Fatal("requests did not match") | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			if !reflect.DeepEqual(tc.err, errResp) { | 
					
						
							|  |  |  | 				t.Fatal("errors did not match") | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 			capturedRequest = nil | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestEmbeddingsMiddleware(t *testing.T) { | 
					
						
							|  |  |  | 	type testCase struct { | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		name string | 
					
						
							|  |  |  | 		body string | 
					
						
							|  |  |  | 		req  api.EmbedRequest | 
					
						
							|  |  |  | 		err  ErrorResponse | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var capturedRequest *api.EmbedRequest | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	testCases := []testCase{ | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			name: "embed handler single input", | 
					
						
							|  |  |  | 			body: `{ | 
					
						
							|  |  |  | 				"input": "Hello", | 
					
						
							|  |  |  | 				"model": "test-model" | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 			req: api.EmbedRequest{ | 
					
						
							|  |  |  | 				Input: "Hello", | 
					
						
							|  |  |  | 				Model: "test-model", | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			name: "embed handler batch input", | 
					
						
							|  |  |  | 			body: `{ | 
					
						
							|  |  |  | 				"input": ["Hello", "World"], | 
					
						
							|  |  |  | 				"model": "test-model" | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 			req: api.EmbedRequest{ | 
					
						
							|  |  |  | 				Input: []any{"Hello", "World"}, | 
					
						
							|  |  |  | 				Model: "test-model", | 
					
						
							| 
									
										
										
										
											2024-07-17 04:36:08 +08:00
										 |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			name: "embed handler error forwarding", | 
					
						
							|  |  |  | 			body: `{ | 
					
						
							|  |  |  | 				"model": "test-model" | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 			err: ErrorResponse{ | 
					
						
							|  |  |  | 				Error: Error{ | 
					
						
							|  |  |  | 					Message: "invalid input", | 
					
						
							|  |  |  | 					Type:    "invalid_request_error", | 
					
						
							|  |  |  | 				}, | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-10 04:48:31 +08:00
										 |  |  | 	endpoint := func(c *gin.Context) { | 
					
						
							|  |  |  | 		c.Status(http.StatusOK) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 	gin.SetMode(gin.TestMode) | 
					
						
							|  |  |  | 	router := gin.New() | 
					
						
							|  |  |  | 	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest)) | 
					
						
							|  |  |  | 	router.Handle(http.MethodPost, "/api/embed", endpoint) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-10 04:48:31 +08:00
										 |  |  | 	for _, tc := range testCases { | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		t.Run(tc.name, func(t *testing.T) { | 
					
						
							|  |  |  | 			req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body)) | 
					
						
							|  |  |  | 			req.Header.Set("Content-Type", "application/json") | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-10 04:48:31 +08:00
										 |  |  | 			resp := httptest.NewRecorder() | 
					
						
							|  |  |  | 			router.ServeHTTP(resp, req) | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			var errResp ErrorResponse | 
					
						
							|  |  |  | 			if resp.Code != http.StatusOK { | 
					
						
							|  |  |  | 				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { | 
					
						
							|  |  |  | 					t.Fatal(err) | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { | 
					
						
							|  |  |  | 				t.Fatal("requests did not match") | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			if !reflect.DeepEqual(tc.err, errResp) { | 
					
						
							|  |  |  | 				t.Fatal("errors did not match") | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 			capturedRequest = nil | 
					
						
							| 
									
										
										
										
											2024-07-10 04:48:31 +08:00
										 |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2024-07-03 07:01:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | func TestListMiddleware(t *testing.T) { | 
					
						
							| 
									
										
										
										
											2024-07-10 04:48:31 +08:00
										 |  |  | 	type testCase struct { | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		name     string | 
					
						
							|  |  |  | 		endpoint func(c *gin.Context) | 
					
						
							|  |  |  | 		resp     string | 
					
						
							| 
									
										
										
										
											2024-07-10 04:48:31 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	testCases := []testCase{ | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			name: "list handler", | 
					
						
							|  |  |  | 			endpoint: func(c *gin.Context) { | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 				c.JSON(http.StatusOK, api.ListResponse{ | 
					
						
							|  |  |  | 					Models: []api.ListModelResponse{ | 
					
						
							|  |  |  | 						{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 							Name:       "test-model", | 
					
						
							|  |  |  | 							ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 						}, | 
					
						
							|  |  |  | 					}, | 
					
						
							|  |  |  | 				}) | 
					
						
							|  |  |  | 			}, | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			resp: `{ | 
					
						
							|  |  |  | 				"object": "list", | 
					
						
							|  |  |  | 				"data": [ | 
					
						
							|  |  |  | 					{ | 
					
						
							|  |  |  | 						"id": "test-model", | 
					
						
							|  |  |  | 						"object": "model", | 
					
						
							|  |  |  | 						"created": 1686935002, | 
					
						
							|  |  |  | 						"owned_by": "library" | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 				] | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			name: "list handler empty output", | 
					
						
							|  |  |  | 			endpoint: func(c *gin.Context) { | 
					
						
							|  |  |  | 				c.JSON(http.StatusOK, api.ListResponse{}) | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			resp: `{ | 
					
						
							|  |  |  | 				"object": "list", | 
					
						
							|  |  |  | 				"data": null | 
					
						
							|  |  |  | 			}`, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 	gin.SetMode(gin.TestMode) | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 	for _, tc := range testCases { | 
					
						
							|  |  |  | 		router := gin.New() | 
					
						
							|  |  |  | 		router.Use(ListMiddleware()) | 
					
						
							|  |  |  | 		router.Handle(http.MethodGet, "/api/tags", tc.endpoint) | 
					
						
							|  |  |  | 		req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil) | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		resp := httptest.NewRecorder() | 
					
						
							|  |  |  | 		router.ServeHTTP(resp, req) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		var expected, actual map[string]any | 
					
						
							|  |  |  | 		err := json.Unmarshal([]byte(tc.resp), &expected) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			t.Fatalf("failed to unmarshal expected response: %v", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		err = json.Unmarshal(resp.Body.Bytes(), &actual) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			t.Fatalf("failed to unmarshal actual response: %v", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if !reflect.DeepEqual(expected, actual) { | 
					
						
							|  |  |  | 			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestRetrieveMiddleware(t *testing.T) { | 
					
						
							|  |  |  | 	type testCase struct { | 
					
						
							|  |  |  | 		name     string | 
					
						
							|  |  |  | 		endpoint func(c *gin.Context) | 
					
						
							|  |  |  | 		resp     string | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	testCases := []testCase{ | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			name: "retrieve handler", | 
					
						
							|  |  |  | 			endpoint: func(c *gin.Context) { | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 				c.JSON(http.StatusOK, api.ShowResponse{ | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 				}) | 
					
						
							|  |  |  | 			}, | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			resp: `{ | 
					
						
							|  |  |  | 				"id":"test-model", | 
					
						
							|  |  |  | 				"object":"model", | 
					
						
							|  |  |  | 				"created":1686935002, | 
					
						
							|  |  |  | 				"owned_by":"library"} | 
					
						
							|  |  |  | 			`, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			name: "retrieve handler error forwarding", | 
					
						
							|  |  |  | 			endpoint: func(c *gin.Context) { | 
					
						
							|  |  |  | 				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"}) | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 			}, | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 			resp: `{ | 
					
						
							|  |  |  | 				"error": { | 
					
						
							|  |  |  | 				  "code": null, | 
					
						
							|  |  |  | 				  "message": "model not found", | 
					
						
							|  |  |  | 				  "param": null, | 
					
						
							|  |  |  | 				  "type": "api_error" | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			}`, | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	gin.SetMode(gin.TestMode) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, tc := range testCases { | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		router := gin.New() | 
					
						
							|  |  |  | 		router.Use(RetrieveMiddleware()) | 
					
						
							|  |  |  | 		router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint) | 
					
						
							|  |  |  | 		req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil) | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		resp := httptest.NewRecorder() | 
					
						
							|  |  |  | 		router.ServeHTTP(resp, req) | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		var expected, actual map[string]any | 
					
						
							|  |  |  | 		err := json.Unmarshal([]byte(tc.resp), &expected) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			t.Fatalf("failed to unmarshal expected response: %v", err) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2024-07-20 02:37:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 01:33:34 +08:00
										 |  |  | 		err = json.Unmarshal(resp.Body.Bytes(), &actual) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			t.Fatalf("failed to unmarshal actual response: %v", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if !reflect.DeepEqual(expected, actual) { | 
					
						
							|  |  |  | 			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2024-07-03 02:50:56 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | } |