Compare commits

..

2 commits

Author SHA1 Message Date
1e99
afcd407f1d add rate limit 2024-10-30 10:33:03 +01:00
1e99
7bea159f2b add middlewares package 2024-10-30 10:24:41 +01:00
3 changed files with 57 additions and 4 deletions

13
main.go
View file

@ -6,7 +6,9 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"time"
"git.1e99.eu/1e99/passed/middlewares"
"git.1e99.eu/1e99/passed/routes" "git.1e99.eu/1e99/passed/routes"
"git.1e99.eu/1e99/passed/storage" "git.1e99.eu/1e99/passed/storage"
) )
@ -24,7 +26,14 @@ func run() error {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("GET /", routes.ServeFiles(embedFS, "static")) mux.Handle("GET /", routes.ServeFiles(embedFS, "static"))
mux.Handle("POST /api/password", routes.CreatePassword(storage, 12*1024, base64.StdEncoding)) mux.Handle(
"POST /api/password",
middlewares.RateLimiter(
routes.CreatePassword(storage, 12*1024, base64.StdEncoding),
1*time.Minute,
5,
),
)
mux.Handle("GET /api/password/{id}", routes.GetPassword(storage, base64.StdEncoding)) mux.Handle("GET /api/password/{id}", routes.GetPassword(storage, base64.StdEncoding))
mux.Handle("HEAD /api/password/{id}", routes.HasPassword(storage)) mux.Handle("HEAD /api/password/{id}", routes.HasPassword(storage))
@ -34,7 +43,7 @@ func run() error {
address = ":3000" address = ":3000"
} }
err = http.ListenAndServe(address, routes.Logger(mux)) err = http.ListenAndServe(address, middlewares.Logger(mux))
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,4 +1,4 @@
package routes package middlewares
import ( import (
"log" "log"
@ -7,7 +7,7 @@ import (
func Logger(handler http.Handler) http.HandlerFunc { func Logger(handler http.Handler) http.HandlerFunc {
return func(res http.ResponseWriter, req *http.Request) { return func(res http.ResponseWriter, req *http.Request) {
log.Printf("%-80s", req.URL.Path) log.Printf("%-30s %-80s", req.RemoteAddr, req.URL.Path)
handler.ServeHTTP(res, req) handler.ServeHTTP(res, req)
} }
} }

View file

@ -0,0 +1,44 @@
package middlewares
import (
"net/http"
"sync"
"time"
)
func RateLimiter(handler http.Handler, clearInterval time.Duration, maxRequests int) http.HandlerFunc {
requests := make(map[string]int)
lock := sync.Mutex{}
ticker := time.NewTicker(clearInterval)
go func() {
for {
<-ticker.C
lock.Lock()
clear(requests)
lock.Unlock()
}
}()
return func(res http.ResponseWriter, req *http.Request) {
addr := req.RemoteAddr
lock.Lock()
count, found := requests[addr]
if !found {
count = 0
}
count += 1
requests[addr] = count
lock.Unlock()
if count > maxRequests {
http.Error(res, "Too many requests", http.StatusTooManyRequests)
return
}
handler.ServeHTTP(res, req)
}
}