| 
									
										
										
										
											2020-12-09 23:22:24 +08:00
										 |  |  | package middleware | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"net/http" | 
					
						
							|  |  |  | 	"net/http/httptest" | 
					
						
							|  |  |  | 	"testing" | 
					
						
							|  |  |  | 	"time" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/grafana/grafana/pkg/models" | 
					
						
							| 
									
										
										
										
											2020-12-11 18:44:44 +08:00
										 |  |  | 	"github.com/grafana/grafana/pkg/setting" | 
					
						
							| 
									
										
										
										
											2020-12-09 23:22:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-11 20:30:59 +08:00
										 |  |  | 	"github.com/grafana/grafana/pkg/web" | 
					
						
							| 
									
										
										
										
											2020-12-09 23:22:24 +08:00
										 |  |  | 	"github.com/stretchr/testify/assert" | 
					
						
							|  |  |  | 	"github.com/stretchr/testify/require" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							// Callback types shared by the rate limiter test scenarios.
							type (
								// execFunc performs one request and returns the recorder that
								// captured the response.
								execFunc func() *httptest.ResponseRecorder

								// advanceTimeFunc moves the scenario's clock forward by deltaTime.
								advanceTimeFunc func(deltaTime time.Duration)

								// rateLimiterScenarioFunc is the body of a scenario: c issues
								// requests and t advances the clock.
								rateLimiterScenarioFunc func(c execFunc, t advanceTimeFunc)
							)
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateLimiterScenarioFunc) { | 
					
						
							| 
									
										
										
										
											2020-12-11 18:44:44 +08:00
										 |  |  | 	t.Helper() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-09 23:22:24 +08:00
										 |  |  | 	t.Run(desc, func(t *testing.T) { | 
					
						
							|  |  |  | 		defaultHandler := func(c *models.ReqContext) { | 
					
						
							|  |  |  | 			resp := make(map[string]interface{}) | 
					
						
							|  |  |  | 			resp["message"] = "OK" | 
					
						
							|  |  |  | 			c.JSON(200, resp) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		currentTime := time.Now() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-11 18:44:44 +08:00
										 |  |  | 		cfg := setting.NewCfg() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-11 20:30:59 +08:00
										 |  |  | 		m := web.New() | 
					
						
							|  |  |  | 		m.UseMiddleware(web.Renderer("../../public/views", "[[", "]]")) | 
					
						
							| 
									
										
										
										
											2020-12-11 18:44:44 +08:00
										 |  |  | 		m.Use(getContextHandler(t, cfg).Middleware) | 
					
						
							| 
									
										
										
										
											2020-12-09 23:22:24 +08:00
										 |  |  | 		m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		fn(func() *httptest.ResponseRecorder { | 
					
						
							|  |  |  | 			resp := httptest.NewRecorder() | 
					
						
							|  |  |  | 			req, err := http.NewRequest("GET", "/foo", nil) | 
					
						
							|  |  |  | 			require.NoError(t, err) | 
					
						
							|  |  |  | 			m.ServeHTTP(resp, req) | 
					
						
							|  |  |  | 			return resp | 
					
						
							|  |  |  | 		}, func(deltaTime time.Duration) { | 
					
						
							|  |  |  | 			currentTime = currentTime.Add(deltaTime) | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestRateLimitMiddleware(t *testing.T) { | 
					
						
							|  |  |  | 	rateLimiterScenario(t, "rate limit calls, with burst", 10, 10, func(doReq execFunc, advanceTime advanceTimeFunc) { | 
					
						
							|  |  |  | 		// first 10 calls succeed
 | 
					
						
							|  |  |  | 		for i := 0; i < 10; i++ { | 
					
						
							|  |  |  | 			resp := doReq() | 
					
						
							|  |  |  | 			assert.Equal(t, 200, resp.Code) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// next one fails
 | 
					
						
							|  |  |  | 		resp := doReq() | 
					
						
							|  |  |  | 		assert.Equal(t, 429, resp.Code) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// check that requests are accepted again in 1 sec
 | 
					
						
							|  |  |  | 		advanceTime(1 * time.Second) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		for i := 0; i < 10; i++ { | 
					
						
							|  |  |  | 			resp := doReq() | 
					
						
							|  |  |  | 			assert.Equal(t, 200, resp.Code) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	rateLimiterScenario(t, "rate limit calls, no burst", 10, 1, func(doReq execFunc, advanceTime advanceTimeFunc) { | 
					
						
							|  |  |  | 		// first calls succeeds
 | 
					
						
							|  |  |  | 		resp := doReq() | 
					
						
							|  |  |  | 		assert.Equal(t, 200, resp.Code) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// immediately fired next one fails
 | 
					
						
							|  |  |  | 		resp = doReq() | 
					
						
							|  |  |  | 		assert.Equal(t, 429, resp.Code) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// but spacing calls out works
 | 
					
						
							|  |  |  | 		for i := 0; i < 10; i++ { | 
					
						
							|  |  |  | 			advanceTime(100 * time.Millisecond) | 
					
						
							|  |  |  | 			resp := doReq() | 
					
						
							|  |  |  | 			assert.Equal(t, 200, resp.Code) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | } |