cpolis/cmd/model/users.go

201 lines
5.0 KiB
Go
Raw Permalink Normal View History

2024-03-09 11:06:03 +01:00
package model
import (
"fmt"
"log"
"golang.org/x/crypto/bcrypt"
)
2024-03-11 21:08:27 +01:00
// Role identifiers, presumably the values carried in User.Role (TODO confirm
// against callers). iota numbers them in declaration order:
// Admin = 0, Publisher = 1, Editor = 2, Author = 3.
//
// AddUser writes User.Role verbatim into the users.role column, so inserting
// or reordering names here would silently change the meaning of stored rows.
const (
	Admin = iota
	Publisher
	Editor
	Author
)
// User is an account record as read from (GetUser) and written to (AddUser)
// the users table. Field order is significant for positional composite
// literals, so new fields should be appended rather than inserted.
type User struct {
	UserName            string
	FirstName, LastName string
	ID                  int64
	Role                int // presumably one of Admin, Publisher, Editor, Author
}
2024-03-28 07:00:37 +01:00
func (db *DB) AddUser(u *User, pass string) (int64, error) {
2024-03-09 11:06:03 +01:00
hashedPass, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost)
if err != nil {
2024-03-28 07:00:37 +01:00
return 0, fmt.Errorf("error creating password hash: %v", err)
2024-03-09 11:06:03 +01:00
}
query := `
INSERT INTO users (username, password, first_name, last_name, role)
2024-03-09 11:06:03 +01:00
VALUES (?, ?, ?, ?, ?)
`
2024-03-28 07:00:37 +01:00
result, err := db.Exec(query, u.UserName, string(hashedPass), u.FirstName, u.LastName, u.Role)
if err != nil {
return 0, fmt.Errorf("error inserting new user %v into DB: %v", u.UserName, err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("error inserting user into DB: %v", err)
2024-03-09 11:06:03 +01:00
}
2024-03-28 07:00:37 +01:00
return id, nil
2024-03-09 11:06:03 +01:00
}
2024-03-11 21:08:27 +01:00
func (db *DB) GetID(userName string) (int64, bool) {
2024-03-09 11:06:03 +01:00
var id int64
query := `
SELECT id
FROM users
WHERE username = ?
`
row := db.QueryRow(query, userName)
if err := row.Scan(&id); err != nil {
2024-03-11 21:08:27 +01:00
return 0, false
2024-03-09 11:06:03 +01:00
}
2024-03-11 21:08:27 +01:00
return id, true
2024-03-09 11:06:03 +01:00
}
// CheckPassword verifies pass against the bcrypt hash stored in the password
// column for user id. It returns nil when they match.
//
// Errors wrap their cause with %w, so a caller can tell a wrong password
// (errors.Is(err, bcrypt.ErrMismatchedHashAndPassword)) apart from a failed
// lookup, e.g. a missing row (errors.Is(err, sql.ErrNoRows), assuming DB is
// backed by database/sql).
func (db *DB) CheckPassword(id int64, pass string) error {
	query := `
		SELECT password
		FROM users
		WHERE id = ?
	`
	var storedHash string
	if err := db.QueryRow(query, id).Scan(&storedHash); err != nil {
		return fmt.Errorf("error reading password from DB: %w", err)
	}
	if err := bcrypt.CompareHashAndPassword([]byte(storedHash), []byte(pass)); err != nil {
		return fmt.Errorf("incorrect password: %w", err)
	}
	return nil
}
2024-03-11 21:08:27 +01:00
func (tx *Tx) ChangePassword(id int64, oldPass, newPass string) error {
2024-03-09 11:06:03 +01:00
var queriedPass string
getQuery := `
SELECT password
FROM users
WHERE id = ?
`
row := tx.QueryRow(getQuery, id)
if err := row.Scan(&queriedPass); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
2024-03-28 07:00:37 +01:00
log.Fatalf("transaction error: %v, rollback error: %v", err, rollbackErr)
2024-03-09 11:06:03 +01:00
}
return fmt.Errorf("error reading password from DB: %v", err)
}
if err := bcrypt.CompareHashAndPassword([]byte(queriedPass), []byte(oldPass)); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
2024-03-28 07:00:37 +01:00
log.Fatalf("transaction error: %v, rollback error: %v", err, rollbackErr)
2024-03-09 11:06:03 +01:00
}
return fmt.Errorf("incorrect password: %v", err)
}
newHashedPass, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost)
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
2024-03-28 07:00:37 +01:00
log.Fatalf("transaction error: %v, rollback error: %v", err, rollbackErr)
2024-03-09 11:06:03 +01:00
}
return fmt.Errorf("error creating password hash: %v", err)
}
setQuery := `
UPDATE users
SET password = ?
WHERE id = ?
`
if _, err = tx.Exec(setQuery, string(newHashedPass), id); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
2024-03-28 07:00:37 +01:00
log.Fatalf("transaction error: %v, rollback error: %v", err, rollbackErr)
2024-03-09 11:06:03 +01:00
}
return fmt.Errorf("error updating password in DB: %v", err)
}
return nil
}
// TODO: No need for ID field in general
func (db *DB) GetUser(id int64) (*User, error) {
user := new(User)
query := `
SELECT id, username, first_name, last_name, role
FROM users
WHERE id = ?
`
row := db.QueryRow(query, id)
2024-03-28 07:00:37 +01:00
if err := row.Scan(&user.ID, &user.UserName, &user.FirstName,
&user.LastName, &user.Role); err != nil {
2024-03-09 11:06:03 +01:00
return nil, fmt.Errorf("error reading user information: %v", err)
}
return user, nil
}
func (db *DB) UpdateUserAttributes(id int64, user, first, last, oldPass, newPass, newPass2 string) error {
passwordEmpty := true
if len(newPass) > 0 || len(newPass2) > 0 {
if newPass != newPass2 {
return fmt.Errorf("error: passwords do not match")
}
passwordEmpty = false
}
tx := new(Tx)
var err error
for i := 0; i < TxMaxRetries; i++ {
err := func() error {
tx.Tx, err = db.Begin()
if err != nil {
return fmt.Errorf("error starting transaction: %v", err)
}
if !passwordEmpty {
if err = tx.ChangePassword(id, oldPass, newPass); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
2024-03-28 07:00:37 +01:00
log.Fatalf("transaction error: %v, rollback error: %v", err, rollbackErr)
}
return fmt.Errorf("error changing password: %v", err)
}
}
if err = tx.UpdateAttributes(
&Attribute{Table: "users", ID: id, AttName: "username", Value: user},
&Attribute{Table: "users", ID: id, AttName: "first_name", Value: first},
&Attribute{Table: "users", ID: id, AttName: "last_name", Value: last},
); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
2024-03-28 07:00:37 +01:00
log.Fatalf("transaction error: %v, rollback error: %v", err, rollbackErr)
}
return fmt.Errorf("error updating attributes in DB: %v", err)
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("error committing transaction: %v", err)
}
return nil
}()
if err == nil {
return nil
}
log.Println(err)
2024-03-28 07:00:37 +01:00
wait(i)
}
return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries)
}