package backend

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"

	"golang.org/x/crypto/bcrypt"
)

// Role values, presumably assigned to User.Role (which AddUser persists in the
// users.role column). The numbering is positional, 0 through 4, and is spelled
// out explicitly so that reordering the list cannot silently renumber roles.
// NOTE(review): NonExistent looks like a sentinel for "no such role" rather
// than an assignable role — confirm with callers.
const (
	Admin       = 0
	Publisher   = 1
	Editor      = 2
	Author      = 3
	NonExistent = 4
)

// User is one account from the users table. It carries no password data: the
// password hash is only ever read and written by the dedicated password
// functions in this file.
type User struct {
	UserName, FirstName, LastName string

	ID   int64 // primary key of the users row
	Role int   // presumably one of the role constants (Admin, Publisher, ...)
}

// AddUser inserts u into the users table, storing a bcrypt hash of pass rather
// than the plain text. It returns the ID the database assigned to the new row.
// u.ID is not read, and u itself is left unmodified.
func (db *DB) AddUser(u *User, pass string) (int64, error) {
	const insertUser = `
    INSERT INTO users (username, password, first_name, last_name, role)
    VALUES (?, ?, ?, ?, ?)
    `

	hash, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost)
	if err != nil {
		return 0, fmt.Errorf("error creating password hash: %v", err)
	}

	res, err := db.Exec(insertUser, u.UserName, string(hash), u.FirstName, u.LastName, u.Role)
	if err != nil {
		return 0, fmt.Errorf("error inserting new user %v into DB: %v", u.UserName, err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("error inserting user into DB: %v", err)
	}
	return newID, nil
}

// GetID looks up the primary key of the user named userName.
//
// The boolean is false when no such user exists or when the lookup fails. The
// signature cannot carry an error, so any failure other than sql.ErrNoRows
// (e.g. a lost connection) is logged. Otherwise it would be indistinguishable
// from "user not found".
func (db *DB) GetID(userName string) (int64, bool) {
	query := `
    SELECT id
    FROM users
    WHERE username = ?
    `

	var id int64
	if err := db.QueryRow(query, userName).Scan(&id); err != nil {
		if !errors.Is(err, sql.ErrNoRows) {
			log.Printf("error looking up ID of user %q: %v", userName, err)
		}
		return 0, false
	}

	return id, true
}

// CheckPassword verifies pass against the bcrypt hash stored for the user with
// primary key id. It returns nil on a match. It returns a non-nil error if the
// stored hash cannot be read (including when the user does not exist) or if
// the password does not match.
func (db *DB) CheckPassword(id int64, pass string) error {
	const selectHash = `
    SELECT password
    FROM users
    WHERE id = ?
    `

	var storedHash string
	if err := db.QueryRow(selectHash, id).Scan(&storedHash); err != nil {
		return fmt.Errorf("error reading password from DB: %v", err)
	}

	err := bcrypt.CompareHashAndPassword([]byte(storedHash), []byte(pass))
	if err != nil {
		return fmt.Errorf("incorrect password: %v", err)
	}
	return nil
}

// ChangePassword replaces the password of user id with newPass inside tx. The
// change happens only if oldPass matches the bcrypt hash currently stored.
//
// On any failure it rolls tx back itself before returning the error. A failed
// rollback terminates the process via log.Fatalf. Callers therefore must not
// assume tx is still usable after an error. On success tx is left open for the
// caller to commit.
func (tx *Tx) ChangePassword(id int64, oldPass, newPass string) error {
	// abort rolls the transaction back after a failed step; cause is the step's error.
	abort := func(cause error) {
		if rbErr := tx.Rollback(); rbErr != nil {
			log.Fatalf("transaction error: %v, rollback error: %v", cause, rbErr)
		}
	}

	const selectHash = `
    SELECT password
    FROM users
    WHERE id = ?
    `
	var storedHash string
	if err := tx.QueryRow(selectHash, id).Scan(&storedHash); err != nil {
		abort(err)
		return fmt.Errorf("error reading password from DB: %v", err)
	}

	if err := bcrypt.CompareHashAndPassword([]byte(storedHash), []byte(oldPass)); err != nil {
		abort(err)
		return fmt.Errorf("incorrect password: %v", err)
	}

	newHash, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost)
	if err != nil {
		abort(err)
		return fmt.Errorf("error creating password hash: %v", err)
	}

	const updateHash = `
    UPDATE users
    SET password = ?
    WHERE id = ?
    `
	if _, err := tx.Exec(updateHash, string(newHash), id); err != nil {
		abort(err)
		return fmt.Errorf("error updating password in DB: %v", err)
	}

	return nil
}

// GetUser loads the account whose primary key is id. It returns an error if
// the row cannot be read, including when no such user exists. The returned
// User has no password data.
//
// TODO: No need for ID field in general
func (db *DB) GetUser(id int64) (*User, error) {
	const selectUser = `
    SELECT id, username, first_name, last_name, role
    FROM users
    WHERE id = ?
    `

	var u User
	err := db.QueryRow(selectUser, id).Scan(&u.ID, &u.UserName, &u.FirstName, &u.LastName, &u.Role)
	if err != nil {
		return nil, fmt.Errorf("error reading user information: %v", err)
	}
	return &u, nil
}

// UpdateOwnAttributes lets user id edit their own account: username, first
// name, last name and, optionally, password.
//
// A password change is attempted only when newPass or newPass2 is non-empty.
// In that case the two must match, and oldPass must verify against the stored
// hash (checked by ChangePassword).
//
// All changes run in one transaction, attempted up to TxMaxRetries times. The
// error returned after the last failed attempt wraps that attempt's error, so
// the underlying reason (e.g. an incorrect old password) reaches the caller.
// NOTE(review): non-transient failures such as a wrong old password are
// retried too; consider failing fast once errors can be classified.
func (db *DB) UpdateOwnAttributes(id int64, user, first, last, oldPass, newPass, newPass2 string) error {
	changePassword := len(newPass) > 0 || len(newPass2) > 0
	if changePassword && newPass != newPass2 {
		return fmt.Errorf("error: passwords do not match")
	}

	tx := new(Tx)

	// rollback aborts tx after a failed step. ChangePassword already rolls tx
	// back on all of its error paths (UpdateAttributes may too — not visible
	// here), so a second Rollback reports sql.ErrTxDone. That is expected and
	// must not kill the process; only a genuine rollback failure stays fatal.
	rollback := func(cause error) {
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			log.Fatalf("transaction error: %v, rollback error: %v", cause, rbErr)
		}
	}

	// attempt performs one complete transaction: begin, update, commit.
	attempt := func() error {
		var err error
		if tx.Tx, err = db.Begin(); err != nil {
			return fmt.Errorf("error starting transaction: %w", err)
		}

		if changePassword {
			if err = tx.ChangePassword(id, oldPass, newPass); err != nil {
				rollback(err)
				return fmt.Errorf("error changing password: %w", err)
			}
		}

		if err = tx.UpdateAttributes(
			&Attribute{Table: "users", ID: id, AttName: "username", Value: user},
			&Attribute{Table: "users", ID: id, AttName: "first_name", Value: first},
			&Attribute{Table: "users", ID: id, AttName: "last_name", Value: last},
		); err != nil {
			rollback(err)
			return fmt.Errorf("error updating attributes in DB: %w", err)
		}

		if err = tx.Commit(); err != nil {
			return fmt.Errorf("error committing transaction: %w", err)
		}
		return nil
	}

	var lastErr error
	for i := 0; i < TxMaxRetries; i++ {
		if lastErr = attempt(); lastErr == nil {
			return nil
		}
		log.Println(lastErr)
		wait(i)
	}

	return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting: %w", TxMaxRetries, lastErr)
}

// AddFirstUser inserts u with password pass, but only while the users table is
// still empty. The count and the insert share one serializable transaction, so
// concurrent callers cannot both see an empty table and both insert. The
// transaction is attempted up to TxMaxRetries times.
//
// It returns the new row's ID when it inserted u. If users already exist,
// nothing is inserted and the fixed value 2 is returned with a nil error.
// NOTE(review): the meaning of that sentinel 2 is defined by callers — confirm.
// Any failed rollback terminates the process via log.Fatalf.
func (db *DB) AddFirstUser(u *User, pass string) (int64, error) {
	const countUsers = "SELECT COUNT(*) FROM users"
	const insertUser = `
    INSERT INTO users (username, password, first_name, last_name, role)
    VALUES (?, ?, ?, ?, ?)
    `
	opts := &sql.TxOptions{Isolation: sql.LevelSerializable}

	// attempt runs the whole count-then-insert transaction once.
	attempt := func() (int64, error) {
		tx, err := db.BeginTx(context.Background(), opts)
		if err != nil {
			return 0, fmt.Errorf("error starting transaction: %v", err)
		}

		// abort rolls tx back after a failed step; cause is the step's error.
		abort := func(cause error) {
			if rbErr := tx.Rollback(); rbErr != nil {
				log.Fatalf("transaction error: %v, rollback error: %v", cause, rbErr)
			}
		}

		var existing int64
		if err = tx.QueryRow(countUsers).Scan(&existing); err != nil {
			abort(err)
			return 0, fmt.Errorf("error getting ID of %v: %v", u.UserName, err)
		}
		if existing != 0 {
			if err = tx.Commit(); err != nil {
				return 0, fmt.Errorf("error committing transaction: %v", err)
			}
			return 2, nil
		}

		hash, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost)
		if err != nil {
			abort(err)
			return 0, fmt.Errorf("error creating password hash: %v", err)
		}

		res, err := tx.Exec(insertUser, u.UserName, string(hash), u.FirstName, u.LastName, u.Role)
		if err != nil {
			abort(err)
			return 0, fmt.Errorf("error inserting new user %v into DB: %v", u.UserName, err)
		}

		newID, err := res.LastInsertId()
		if err != nil {
			abort(err)
			return 0, fmt.Errorf("error inserting user into DB: %v", err)
		}

		if err = tx.Commit(); err != nil {
			return 0, fmt.Errorf("error committing transaction: %v", err)
		}
		return newID, nil
	}

	for i := 0; i < TxMaxRetries; i++ {
		id, err := attempt()
		if err == nil {
			return id, nil
		}
		log.Println(err)
		wait(i)
	}
	return 0, fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries)
}

// GetAllUsers returns every user in the users table, keyed by ID. Password
// hashes are not loaded. On any error no partial result is returned.
func (db *DB) GetAllUsers() (map[int64]*User, error) {
	query := "SELECT id, username, first_name, last_name, role FROM users"

	rows, err := db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("error getting all users from DB: %w", err)
	}
	// rows.Next only closes the result set automatically when iteration runs
	// to completion. The explicit Close releases the connection on the early
	// return below as well.
	defer rows.Close()

	users := make(map[int64]*User)
	for rows.Next() {
		user := new(User)
		if err := rows.Scan(&user.ID, &user.UserName, &user.FirstName,
			&user.LastName, &user.Role); err != nil {
			return nil, fmt.Errorf("error getting user info: %w", err)
		}
		users[user.ID] = user
	}
	// Next returning false can mean an iteration error rather than end of data.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating over users from DB: %w", err)
	}

	return users, nil
}

// SetPassword overwrites the stored password of user id with a bcrypt hash of
// newPass, inside tx and without checking the previous password.
//
// On failure it rolls tx back itself before returning. A failed rollback
// terminates the process via log.Fatalf. On success tx is left open for the
// caller to commit.
func (tx *Tx) SetPassword(id int64, newPass string) error {
	// abort rolls the transaction back after a failed step; cause is the step's error.
	abort := func(cause error) {
		if rbErr := tx.Rollback(); rbErr != nil {
			log.Fatalf("transaction error: %v, rollback error: %v", cause, rbErr)
		}
	}

	hash, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost)
	if err != nil {
		abort(err)
		return fmt.Errorf("error creating password hash: %v", err)
	}

	const updateHash = "UPDATE users SET password = ? WHERE id = ?"
	if _, err = tx.Exec(updateHash, string(hash), id); err != nil {
		abort(err)
		return fmt.Errorf("error updating password in DB: %v", err)
	}

	return nil
}

// UpdateUserAttributes overwrites the username, first name, last name and
// role of the user with primary key id.
//
// If newPass or newPass2 is non-empty, the two must match. The password is
// then replaced via SetPassword, which does not ask for the previous password.
//
// All changes run in one transaction, attempted up to TxMaxRetries times. The
// error returned after the last failed attempt wraps that attempt's error.
func (db *DB) UpdateUserAttributes(id int64, user, first, last, newPass, newPass2 string, role int) error {
	changePassword := len(newPass) > 0 || len(newPass2) > 0
	if changePassword && newPass != newPass2 {
		return fmt.Errorf("error: passwords do not match")
	}

	tx := new(Tx)

	// rollback aborts tx after a failed step. SetPassword already rolls tx back
	// on all of its error paths (UpdateAttributes may too — not visible here),
	// so a second Rollback reports sql.ErrTxDone. That is expected and must not
	// kill the process; only a genuine rollback failure stays fatal.
	rollback := func(cause error) {
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			log.Fatalf("transaction error: %v, rollback error: %v", cause, rbErr)
		}
	}

	// attempt performs one complete transaction: begin, update, commit.
	attempt := func() error {
		var err error
		if tx.Tx, err = db.Begin(); err != nil {
			return fmt.Errorf("error starting transaction: %w", err)
		}

		if changePassword {
			if err = tx.SetPassword(id, newPass); err != nil {
				rollback(err)
				return fmt.Errorf("error changing password: %w", err)
			}
		}

		if err = tx.UpdateAttributes(
			&Attribute{Table: "users", ID: id, AttName: "username", Value: user},
			&Attribute{Table: "users", ID: id, AttName: "first_name", Value: first},
			&Attribute{Table: "users", ID: id, AttName: "last_name", Value: last},
			&Attribute{Table: "users", ID: id, AttName: "role", Value: role},
		); err != nil {
			rollback(err)
			return fmt.Errorf("error updating attributes in DB: %w", err)
		}

		if err = tx.Commit(); err != nil {
			return fmt.Errorf("error committing transaction: %w", err)
		}
		return nil
	}

	var lastErr error
	for i := 0; i < TxMaxRetries; i++ {
		if lastErr = attempt(); lastErr == nil {
			return nil
		}
		log.Println(lastErr)
		wait(i)
	}

	return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting: %w", TxMaxRetries, lastErr)
}

// DeleteUser removes the user with primary key id. The number of affected rows
// is not checked, so deleting an id that does not exist is not an error.
func (db *DB) DeleteUser(id int64) error {
	if _, err := db.Exec("DELETE FROM users WHERE id = ?", id); err != nil {
		return fmt.Errorf("error deleting user %v from DB: %v", id, err)
	}
	return nil
}