diff --git a/main.go b/main.go index 9d1b4f2..9092167 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "os" + "time" "git.1e99.eu/1e99/passed/middlewares" "git.1e99.eu/1e99/passed/routes" @@ -25,7 +26,14 @@ func run() error { mux := http.NewServeMux() mux.Handle("GET /", routes.ServeFiles(embedFS, "static")) - mux.Handle("POST /api/password", routes.CreatePassword(storage, 12*1024, base64.StdEncoding)) + mux.Handle( + "POST /api/password", + middlewares.RateLimiter( + routes.CreatePassword(storage, 12*1024, base64.StdEncoding), + 1*time.Minute, + 5, + ), + ) mux.Handle("GET /api/password/{id}", routes.GetPassword(storage, base64.StdEncoding)) mux.Handle("HEAD /api/password/{id}", routes.HasPassword(storage)) diff --git a/middlewares/logger.go b/middlewares/logger.go index 869b00d..10b474e 100644 --- a/middlewares/logger.go +++ b/middlewares/logger.go @@ -7,7 +7,7 @@ import ( func Logger(handler http.Handler) http.HandlerFunc { return func(res http.ResponseWriter, req *http.Request) { - log.Printf("%-80s", req.URL.Path) + log.Printf("%-30s %-80s", req.RemoteAddr, req.URL.Path) handler.ServeHTTP(res, req) } } diff --git a/middlewares/rate_limiter.go b/middlewares/rate_limiter.go new file mode 100644 index 0000000..9640d8e --- /dev/null +++ b/middlewares/rate_limiter.go @@ -0,0 +1,44 @@ +package middlewares + +import ( + "net/http" + "sync" + "time" +) + +func RateLimiter(handler http.Handler, clearInterval time.Duration, maxRequests int) http.HandlerFunc { + requests := make(map[string]int) + lock := sync.Mutex{} + + ticker := time.NewTicker(clearInterval) + go func() { + for { + <-ticker.C + + lock.Lock() + clear(requests) + lock.Unlock() + } + }() + + return func(res http.ResponseWriter, req *http.Request) { + addr := req.RemoteAddr + + lock.Lock() + count, found := requests[addr] + if !found { + count = 0 + } + + count += 1 + requests[addr] = count + lock.Unlock() + + if count > maxRequests { + http.Error(res, "Too many requests", http.StatusTooManyRequests) + return + } + + handler.ServeHTTP(res, req) + } +}