package main import ( "embed" "errors" "io/fs" "log" "net/http" "os" "strconv" "time" "git.1e99.eu/1e99/passed/routes" "git.1e99.eu/1e99/passed/storage" ) //go:embed static/* var embedFS embed.FS func run() error { store, err := newStore() if err != nil { return err } staticFS, err := newStaticFS() if err != nil { return err } var address string var logRequests bool var maxPasswordLength int env("PASSED_ADDRESS", &address, ":3000") env("PASSED_LOG_REQUESTS", &logRequests, "true") env("PASSED_MAX_LENGTH", &maxPasswordLength, "12288") mux := http.NewServeMux() handler := http.Handler(mux) mux.Handle("GET /", http.FileServerFS(staticFS)) mux.Handle("POST /api/password", routes.CreatePassword(store, maxPasswordLength)) mux.Handle("GET /api/password/{id}", routes.GetPassword(store)) mux.Handle("HEAD /api/password/{id}", routes.HasPassword(store)) if logRequests { handler = routes.Logger(handler) } log.Printf("Listening on %s.", address) err = http.ListenAndServe(address, handler) if err != nil { return err } return nil } func newStaticFS() (sfs fs.FS, err error) { var fsType string env("PASSED_STATIC_TYPE", &fsType, "embed") switch fsType { case "embed": sfs, err = fs.Sub(embedFS, "static") return case "dir", "directory": var path string env("PASSED_STATIC_DIR_PATH", &path, "static") sfs = os.DirFS(path) return default: err = errors.New("unkown fs type") return } } func newStore() (store storage.Store, err error) { var storeType string var clearInterval int env("PASSED_STORE_TYPE", &storeType, "ram") env("PASSED_STORE_CLEAR_INTERVAL", &clearInterval, "30") switch storeType { case "ram": store = storage.NewRamStore() case "dir", "directory": var path string env("PASSED_STORE_DIR_PATH", &path, "passwords") err = os.MkdirAll(path, os.ModePerm) if err != nil { return } store = storage.NewDirStore(path) default: err = errors.New("unknown storage type") return } go func() { ticker := time.Tick(time.Duration(clearInterval) * time.Second) for { <-ticker err := store.ClearExpired() if err != nil { log.Printf("Failed to clear expired passwords: %s", err) continue } log.Printf("Cleared expired passwords.") } }() return } func main() { err := run() if err != nil { log.Fatalf("%s", err) } } func env(name string, out any, def string) { raw := os.Getenv(name) if raw == "" { raw = def log.Printf("No \"%s\" provided, defaulting to \"%s\".", name, def) } switch value := out.(type) { case *int: i, err := strconv.ParseInt(raw, 10, 64) if err != nil { log.Printf("\"%s\" is not a number (\"%s\").", name, raw) return } *value = int(i) case *bool: switch raw { case "true", "TRUE", "1": *value = true case "false", "FALSE", "0": *value = false } case *string: *value = raw } }