diff --git a/main.go b/main.go index 71ca7af..09c6b0c 100644 --- a/main.go +++ b/main.go @@ -30,7 +30,7 @@ func run() error { address := os.Getenv("PASSED_ADDRESS") if address == "" { - log.Printf("No PASSED_ADDRESS specified, defaulting to \":3000\"") + log.Printf("No PASSED_ADDRESS provided, defaulting to \":3000\"") address = ":3000" } diff --git a/storage/dir.go b/storage/dir.go new file mode 100644 index 0000000..884b0dd --- /dev/null +++ b/storage/dir.go @@ -0,0 +1,169 @@ +package storage + +import ( + "encoding/gob" + "log" + "os" + "path" + "time" +) + +func NewDirStore(clearInterval time.Duration, path string) Store { + store := &dir{ + clearInterval: clearInterval, + timeLayout: time.RFC3339Nano, + path: path, + close: make(chan bool), + } + + go store.clearExpired() + return store +} + +type dir struct { + clearInterval time.Duration + timeLayout string + path string + close chan bool +} + +func (store *dir) CreatePassword(password []byte, expiresAt time.Time) (string, error) { + for range 1000 { + id := generateId(24) + path := store.getPath(id) + + file, err := os.OpenFile( + path, + os.O_CREATE|os.O_EXCL|os.O_WRONLY, + os.ModePerm, + ) + switch { + case os.IsExist(err): + continue + case err != nil: + return "", err + } + + defer file.Close() + + entry := entry{ + Password: password, + ExpiresAt: expiresAt, + } + + err = gob.NewEncoder(file).Encode(&entry) + if err != nil { + log.Printf("%s", err) + return "", err + } + + return id, nil + } + + return "", ErrFull +} + +func (store *dir) GetPassword(id string) ([]byte, error) { + path := store.getPath(id) + file, err := os.OpenFile( + path, + os.O_RDONLY, + 0, + ) + switch { + case os.IsNotExist(err): + return nil, ErrNotFound + case err != nil: + return nil, err + } + + defer file.Close() + + var entry entry + err = gob.NewDecoder(file).Decode(&entry) + if err != nil { + return nil, err + } + + // Close file early as we need to delete it + file.Close() + err = os.Remove(path) + if err != nil { + return nil, err + } + + return entry.Password, nil +} + +func (store *dir) HasPassword(id string) (bool, error) { + path := store.getPath(id) + _, err := os.Stat(path) + switch { + case os.IsNotExist(err): + return false, nil + case err != nil: + return false, err + } + + return true, nil +} + +func (store *dir) Close() error { + store.close <- true + return nil +} + +func (store *dir) clearExpired() error { + ticker := time.NewTicker(store.clearInterval) + defer ticker.Stop() + + for { + select { + case <-store.close: + return nil + case <-ticker.C: + // TODO: Error handling? + now := time.Now() + + entries, err := os.ReadDir(store.path) + if err != nil { + continue + } + + for _, file := range entries { + id := file.Name() + path := store.getPath(id) + file, err := os.OpenFile( + path, + os.O_RDONLY, + 0, + ) + if err != nil { + continue + } + + defer file.Close() + + var entry entry + err = gob.NewDecoder(file).Decode(&entry) + if err != nil { + continue + } + + if now.After(entry.ExpiresAt) { + // Close file early as we need to delete it + file.Close() + + err := os.Remove(path) + if err != nil { + continue + } + } + } + } + } +} + +func (store *dir) getPath(id string) string { + return path.Join(store.path, id) +} diff --git a/storage/ram.go b/storage/ram.go index e19e68d..67941de 100644 --- a/storage/ram.go +++ b/storage/ram.go @@ -1,7 +1,6 @@ package storage import ( - "math/rand/v2" "sync" "time" ) @@ -30,15 +29,15 @@ func (store *ram) CreatePassword(password []byte, expiresAt time.Time) (string, defer store.lock.Unlock() for range 1000 { - id := store.generateId(24) + id := generateId(24) _, found := store.passwords[id] if found { continue } store.passwords[id] = entry{ - password: password, - expiresAt: expiresAt, + Password: password, + ExpiresAt: expiresAt, } return id, nil @@ -57,7 +56,7 @@ func (store *ram) GetPassword(id string) ([]byte, error) { } delete(store.passwords, id) - return password.password, nil + return password.Password, nil } func (store *ram) HasPassword(id string) (bool, error) { @@ -86,7 +85,7 @@ func (store *ram) clearExpired() error { time := time.Now() for id, password := range store.passwords { - if time.After(password.expiresAt) { + if time.After(password.ExpiresAt) { delete(store.passwords, id) } } @@ -95,14 +94,3 @@ func (store *ram) clearExpired() error { } } } - -func (store *ram) generateId(length int) string { - runes := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") - - str := make([]rune, length) - for i := range str { - str[i] = runes[rand.IntN(len(runes))] - } - - return string(str) -} diff --git a/storage/storage.go b/storage/storage.go index a968ab4..aa398d1 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -3,6 +3,7 @@ package storage import ( "errors" "log" + "math/rand/v2" "os" "strings" "time" @@ -14,8 +15,8 @@ var ( ) type entry struct { - password []byte - expiresAt time.Time + Password []byte + ExpiresAt time.Time } type Store interface { @@ -32,9 +33,28 @@ func NewStore() (Store, error) { switch storeType { case "ram": return NewRamStore(20 * time.Second), nil + case "dir": + path := os.Getenv("PASSED_STORE_DIR_PATH") + if path == "" { + log.Printf("No PASSED_STORE_DIR_PATH provided, defaulting to \"passwords\".") + path = "passwords" + } + + return NewDirStore(60*time.Second, path), nil default: log.Printf("No PASSED_STORE_TYPE provided, defaulting to memory store.") return NewRamStore(20 * time.Second), nil } } + +func generateId(length int) string { + runes := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + + str := make([]rune, length) + for i := range str { + str[i] = runes[rand.IntN(len(runes))] + } + + return string(str) +}