mirror of https://github.com/ollama/ollama.git
				
				
				
			Add unit test of API routes (#1528)
This commit is contained in:
		
							parent
							
								
									6e16098a60
								
							
						
					
					
						commit
						630518f0d9
					
				|  | @ -1035,12 +1035,7 @@ func RunServer(cmd *cobra.Command, _ []string) error { | |||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	var origins []string | ||||
| 	if o := os.Getenv("OLLAMA_ORIGINS"); o != "" { | ||||
| 		origins = strings.Split(o, ",") | ||||
| 	} | ||||
| 
 | ||||
| 	return server.Serve(ln, origins) | ||||
| 	return server.Serve(ln) | ||||
| } | ||||
| 
 | ||||
| func getImageData(filePath string) ([]byte, error) { | ||||
|  |  | |||
							
								
								
									
										3
									
								
								go.mod
								
								
								
								
							
							
						
						
									
										3
									
								
								go.mod
								
								
								
								
							|  | @ -7,11 +7,14 @@ require ( | |||
| 	github.com/gin-gonic/gin v1.9.1 | ||||
| 	github.com/olekukonko/tablewriter v0.0.5 | ||||
| 	github.com/spf13/cobra v1.7.0 | ||||
| 	github.com/stretchr/testify v1.8.3 | ||||
| 	golang.org/x/sync v0.3.0 | ||||
| ) | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||
| 	github.com/mattn/go-runewidth v0.0.14 // indirect | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	github.com/rivo/uniseg v0.2.0 // indirect | ||||
| ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -32,6 +32,10 @@ import ( | |||
| 
 | ||||
| var mode string = gin.DebugMode | ||||
| 
 | ||||
| type Server struct { | ||||
| 	WorkDir string | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	switch mode { | ||||
| 	case gin.DebugMode: | ||||
|  | @ -800,27 +804,27 @@ var defaultAllowOrigins = []string{ | |||
| 	"0.0.0.0", | ||||
| } | ||||
| 
 | ||||
| func Serve(ln net.Listener, allowOrigins []string) error { | ||||
| 	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { | ||||
| 		// clean up unused layers and manifests
 | ||||
| 		if err := PruneLayers(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		manifestsPath, err := GetManifestPath() | ||||
| func NewServer() (*Server, error) { | ||||
| 	workDir, err := os.MkdirTemp("", "ollama") | ||||
| 	if err != nil { | ||||
| 			return err | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 		if err := PruneDirectory(manifestsPath); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	return &Server{ | ||||
| 		WorkDir: workDir, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) GenerateRoutes() http.Handler { | ||||
| 	var origins []string | ||||
| 	if o := os.Getenv("OLLAMA_ORIGINS"); o != "" { | ||||
| 		origins = strings.Split(o, ",") | ||||
| 	} | ||||
| 
 | ||||
| 	config := cors.DefaultConfig() | ||||
| 	config.AllowWildcard = true | ||||
| 
 | ||||
| 	config.AllowOrigins = allowOrigins | ||||
| 	config.AllowOrigins = origins | ||||
| 	for _, allowOrigin := range defaultAllowOrigins { | ||||
| 		config.AllowOrigins = append(config.AllowOrigins, | ||||
| 			fmt.Sprintf("http://%s", allowOrigin), | ||||
|  | @ -830,17 +834,11 @@ func Serve(ln net.Listener, allowOrigins []string) error { | |||
| 		) | ||||
| 	} | ||||
| 
 | ||||
| 	workDir, err := os.MkdirTemp("", "ollama") | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer os.RemoveAll(workDir) | ||||
| 
 | ||||
| 	r := gin.Default() | ||||
| 	r.Use( | ||||
| 		cors.New(config), | ||||
| 		func(c *gin.Context) { | ||||
| 			c.Set("workDir", workDir) | ||||
| 			c.Set("workDir", s.WorkDir) | ||||
| 			c.Next() | ||||
| 		}, | ||||
| 	) | ||||
|  | @ -868,8 +866,34 @@ func Serve(ln net.Listener, allowOrigins []string) error { | |||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	return r | ||||
| } | ||||
| 
 | ||||
| func Serve(ln net.Listener) error { | ||||
| 	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { | ||||
| 		// clean up unused layers and manifests
 | ||||
| 		if err := PruneLayers(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		manifestsPath, err := GetManifestPath() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		if err := PruneDirectory(manifestsPath); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	s, err := NewServer() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	r := s.GenerateRoutes() | ||||
| 
 | ||||
| 	log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version) | ||||
| 	s := &http.Server{ | ||||
| 	srvr := &http.Server{ | ||||
| 		Handler: r, | ||||
| 	} | ||||
| 
 | ||||
|  | @ -881,7 +905,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { | |||
| 		if loaded.runner != nil { | ||||
| 			loaded.runner.Close() | ||||
| 		} | ||||
| 		os.RemoveAll(workDir) | ||||
| 		os.RemoveAll(s.WorkDir) | ||||
| 		os.Exit(0) | ||||
| 	}() | ||||
| 
 | ||||
|  | @ -892,7 +916,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return s.Serve(ln) | ||||
| 	return srvr.Serve(ln) | ||||
| } | ||||
| 
 | ||||
| func waitForStream(c *gin.Context, ch chan interface{}) { | ||||
|  |  | |||
|  | @ -0,0 +1,70 @@ | |||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| func setupServer(t *testing.T) (*Server, error) { | ||||
| 	t.Helper() | ||||
| 
 | ||||
| 	return NewServer() | ||||
| } | ||||
| 
 | ||||
| func Test_Routes(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		Name     string | ||||
| 		Method   string | ||||
| 		Path     string | ||||
| 		Setup    func(t *testing.T, req *http.Request) | ||||
| 		Expected func(t *testing.T, resp *http.Response) | ||||
| 	} | ||||
| 
 | ||||
| 	testCases := []testCase{ | ||||
| 		{ | ||||
| 			Name:   "Version Handler", | ||||
| 			Method: http.MethodGet, | ||||
| 			Path:   "/api/version", | ||||
| 			Setup: func(t *testing.T, req *http.Request) { | ||||
| 			}, | ||||
| 			Expected: func(t *testing.T, resp *http.Response) { | ||||
| 				contentType := resp.Header.Get("Content-Type") | ||||
| 				assert.Equal(t, contentType, "application/json; charset=utf-8") | ||||
| 				body, err := io.ReadAll(resp.Body) | ||||
| 				assert.Nil(t, err) | ||||
| 				assert.Equal(t, `{"version":"0.0.0"}`, string(body)) | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	s, err := setupServer(t) | ||||
| 	assert.Nil(t, err) | ||||
| 
 | ||||
| 	router := s.GenerateRoutes() | ||||
| 
 | ||||
| 	httpSrv := httptest.NewServer(router) | ||||
| 	t.Cleanup(httpSrv.Close) | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		u := httpSrv.URL + tc.Path | ||||
| 		req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) | ||||
| 		assert.Nil(t, err) | ||||
| 
 | ||||
| 		if tc.Setup != nil { | ||||
| 			tc.Setup(t, req) | ||||
| 		} | ||||
| 
 | ||||
| 		resp, err := httpSrv.Client().Do(req) | ||||
| 		assert.Nil(t, err) | ||||
| 
 | ||||
| 		if tc.Expected != nil { | ||||
| 			tc.Expected(t, resp) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
		Loading…
	
		Reference in New Issue