diff --git a/routes/create_password.go b/routes/create_password.go index 5464074..b5b3bd0 100644 --- a/routes/create_password.go +++ b/routes/create_password.go @@ -49,7 +49,7 @@ func CreatePassword(store storage.Store, maxLength int, encoding *base64.Encodin return } - id, err := store.CreatePassword( + id, err := store.Create( password, time.Now().Add(expiresIn), ) diff --git a/routes/get_password.go b/routes/get_password.go index ed7d3d4..9495c1e 100644 --- a/routes/get_password.go +++ b/routes/get_password.go @@ -11,7 +11,7 @@ import ( func GetPassword(store storage.Store, encoding *base64.Encoding) http.HandlerFunc { return func(res http.ResponseWriter, req *http.Request) { id := req.PathValue("id") - password, err := store.GetPassword(id) + password, err := store.Get(id) switch { case err == storage.ErrNotFound: http.Error(res, "Password not found", http.StatusNotFound) @@ -21,8 +21,13 @@ func GetPassword(store storage.Store, encoding *base64.Encoding) http.HandlerFun return } - encodedPassword := encoding.EncodeToString(password) + err = store.Delete(id) + if err != nil { + http.Error(res, "", http.StatusInternalServerError) + return + } + encodedPassword := encoding.EncodeToString(password) resBody := struct { Password string `json:"password"` }{ diff --git a/routes/has_password.go b/routes/has_password.go index 25f4579..34e8d4b 100644 --- a/routes/has_password.go +++ b/routes/has_password.go @@ -9,16 +9,16 @@ import ( func HasPassword(store storage.Store) http.HandlerFunc { return func(res http.ResponseWriter, req *http.Request) { id := req.PathValue("id") - found, err := store.HasPassword(id) - if err != nil { + _, err := store.Get(id) + switch { + case err == storage.ErrNotFound: + http.Error(res, "", http.StatusNotFound) + return + case err != nil: http.Error(res, "", http.StatusInternalServerError) return } - if found { - res.WriteHeader(http.StatusNoContent) - } else { - res.WriteHeader(http.StatusNotFound) - } + res.WriteHeader(http.StatusNoContent) } } diff --git a/storage/dir.go b/storage/dir.go index 884b0dd..aab0d5c 100644 --- a/storage/dir.go +++ b/storage/dir.go @@ -27,7 +27,7 @@ type dir struct { close chan bool } -func (store *dir) CreatePassword(password []byte, expiresAt time.Time) (string, error) { +func (store *dir) Create(password []byte, expiresAt time.Time) (string, error) { for range 1000 { id := generateId(24) path := store.getPath(id) @@ -63,7 +63,7 @@ func (store *dir) CreatePassword(password []byte, expiresAt time.Time) (string, return "", ErrFull } -func (store *dir) GetPassword(id string) ([]byte, error) { +func (store *dir) Get(id string) ([]byte, error) { path := store.getPath(id) file, err := os.OpenFile( path, @@ -85,27 +85,17 @@ func (store *dir) GetPassword(id string) ([]byte, error) { return nil, err } - // Close file early as we need to delete it - file.Close() - err = os.Remove(path) - if err != nil { - return nil, err - } - return entry.Password, nil } -func (store *dir) HasPassword(id string) (bool, error) { +func (store *dir) Delete(id string) error { path := store.getPath(id) - _, err := os.Stat(path) - switch { - case os.IsNotExist(err): - return false, nil - case err != nil: - return false, err + err := os.Remove(path) + if err != nil { + return nil } - return true, nil + return nil } func (store *dir) Close() error { @@ -153,11 +143,7 @@ func (store *dir) clearExpired() error { if now.After(entry.ExpiresAt) { // Close file early as we need to delete it file.Close() - - err := os.Remove(path) - if err != nil { - continue - } + store.Delete(id) } } } diff --git a/storage/ram.go b/storage/ram.go index 67941de..774fa8b 100644 --- a/storage/ram.go +++ b/storage/ram.go @@ -24,7 +24,7 @@ type ram struct { close chan bool } -func (store *ram) CreatePassword(password []byte, expiresAt time.Time) (string, error) { +func (store *ram) Create(password []byte, expiresAt time.Time) (string, error) { store.lock.Lock() defer store.lock.Unlock() @@ -46,7 +46,7 @@ func (store *ram) CreatePassword(password []byte, expiresAt time.Time) (string, return "", ErrFull } -func (store *ram) GetPassword(id string) ([]byte, error) { +func (store *ram) Get(id string) ([]byte, error) { store.lock.Lock() defer store.lock.Unlock() @@ -55,16 +55,12 @@ func (store *ram) GetPassword(id string) ([]byte, error) { return nil, ErrNotFound } - delete(store.passwords, id) return password.Password, nil } -func (store *ram) HasPassword(id string) (bool, error) { - store.lock.Lock() - defer store.lock.Unlock() - - _, found := store.passwords[id] - return found, nil +func (store *ram) Delete(id string) error { + delete(store.passwords, id) + return nil } func (store *ram) Close() error { @@ -86,7 +82,7 @@ func (store *ram) clearExpired() error { for id, password := range store.passwords { if time.After(password.ExpiresAt) { - delete(store.passwords, id) + store.Delete(id) } } diff --git a/storage/storage.go b/storage/storage.go index aa398d1..db3f316 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -20,9 +20,9 @@ type entry struct { } type Store interface { - CreatePassword(password []byte, expiresAt time.Time) (string, error) - GetPassword(id string) ([]byte, error) - HasPassword(id string) (bool, error) + Create(password []byte, expiresAt time.Time) (string, error) + Get(id string) ([]byte, error) + Delete(id string) error Close() error }