diff --git a/cmd/data/db.go b/cmd/data/db.go index 361678c..5961a34 100644 --- a/cmd/data/db.go +++ b/cmd/data/db.go @@ -94,9 +94,21 @@ func (db *DB) ChangePassword(id int64, oldPass, newPass string) error { return fmt.Errorf("error starting transaction: %v", err) } - if err := db.CheckPassword(id, oldPass); err != nil { + var queriedPass string + getQuery := ` + SELECT password + FROM users + WHERE id = ? + ` + row := tx.QueryRow(getQuery, id) + if err := row.Scan(&queriedPass); err != nil { tx.Rollback() - return fmt.Errorf("error checking password: %v", err) + return fmt.Errorf("error reading password from DB: %v", err) + } + + if err := bcrypt.CompareHashAndPassword([]byte(queriedPass), []byte(oldPass)); err != nil { + tx.Rollback() + return fmt.Errorf("incorrect password: %v", err) } newHashedPass, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost) @@ -105,12 +117,12 @@ func (db *DB) ChangePassword(id int64, oldPass, newPass string) error { return fmt.Errorf("error creating password hash: %v", err) } - query := ` + setQuery := ` UPDATE users SET password = ? WHERE id = ? ` - if _, err = db.Exec(query, string(newHashedPass), id); err != nil { + if _, err = tx.Exec(setQuery, string(newHashedPass), id); err != nil { tx.Rollback() return fmt.Errorf("error updating password in DB: %v", err) }