package backend

import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"log"
	"os"

	"golang.org/x/crypto/bcrypt"
)

// User roles, stored in User.Role and persisted in the users.role column
// (see AddUser). The values are assigned by iota in declaration order
// (Admin = 0 … NonExistent = 4), so existing entries must not be reordered
// or have new entries inserted between them.
// NOTE(review): lower values presumably mean more privileges — confirm
// against the permission checks that consume these constants.
const (
	Admin = iota
	Publisher
	Editor
	Author
	NonExistent // presumably a sentinel for "no such role/user" — TODO confirm
)

// User is the in-memory form of a row in the users table.
//
// FirstName, LastName and Email hold plaintext here; the DB methods in this
// file AES-GCM encrypt them before writing and decrypt them after reading.
// The bcrypt password hash is not a field of User; it is written and checked
// separately (see AddUser and CheckPassword).
type User struct {
	UserName       string
	FirstName      string
	LastName       string
	Email          string
	ProfilePicLink string // stored unencrypted
	ID             int64  // the users.id column (AddUser returns it via LastInsertId)
	Role           int    // one of Admin, Publisher, Editor, Author
}

// readKey loads the AES-256 key stored in filename.
//
// The file must contain the standard (padded) base64 encoding of exactly 32
// bytes. Line breaks are tolerated because encoding/base64 skips '\r' and
// '\n' while decoding, so a key file saved with a trailing newline is still
// accepted. The length is validated after decoding: checking the encoded
// length instead would reject such files, and would also accept 44-character
// inputs that decode to 31 or 33 bytes (an invalid AES key size).
//
// Errors wrap the underlying cause with %w, so callers can use errors.Is
// (e.g. with os.ErrNotExist).
func readKey(filename string) ([]byte, error) {
	encoded, err := os.ReadFile(filename)
	if err != nil {
		return nil, fmt.Errorf("error reading from aes key file: %w", err)
	}

	key, err := base64.StdEncoding.DecodeString(string(encoded))
	if err != nil {
		return nil, fmt.Errorf("error base64 decoding key: %w", err)
	}

	if len(key) != 32 {
		return nil, errors.New("key is not 32 bytes long")
	}

	return key, nil
}

// key returns the AES-256 key used to encrypt personal user data, loading it
// from c.AESKeyFile via readKey.
//
// Only when the key file does not exist yet is a fresh random 32-byte key
// generated and stored there (standard base64, file mode 0600). Every other
// failure to load the key (unreadable, malformed or truncated file) is
// returned as an error. Replacing an existing key would make every value
// already encrypted with it permanently undecryptable.
func key(c *Config) ([]byte, error) {
	k, readErr := readKey(c.AESKeyFile)
	if readErr == nil {
		return k, nil
	}

	if _, statErr := os.Stat(c.AESKeyFile); !errors.Is(statErr, os.ErrNotExist) {
		return nil, fmt.Errorf("error loading existing aes key: %v", readErr)
	}

	k = make([]byte, 32)
	if _, err := rand.Read(k); err != nil {
		return nil, fmt.Errorf("error generating random key: %v", err)
	}

	// O_EXCL makes creation atomic: if another caller created the key file
	// after the Stat above, never overwrite it — load that key instead. (If
	// the other caller has not finished writing yet, readKey fails and the
	// error is returned; that is transient, unlike losing a key.)
	f, err := os.OpenFile(c.AESKeyFile, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if errors.Is(err, os.ErrExist) {
		return readKey(c.AESKeyFile)
	}
	if err != nil {
		return nil, fmt.Errorf("error writing key to file: %v", err)
	}

	_, err = f.WriteString(base64.StdEncoding.EncodeToString(k))
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		// Best effort: a truncated key file would make every later call
		// fail, so try to remove it; the write error is what gets reported.
		_ = os.Remove(c.AESKeyFile)
		return nil, fmt.Errorf("error writing key to file: %v", err)
	}

	return k, nil
}

// aesEncrypt encrypts plaintext with AES-GCM, using the key obtained from
// key(c).
//
// It returns the standard base64 encoding of nonce || ciphertext, where
// ciphertext includes the GCM authentication tag. A new random nonce is
// drawn on every call, so encrypting the same plaintext twice yields
// different output. aesDecrypt reverses this layout.
func aesEncrypt(c *Config, plaintext string) (string, error) {
	secret, err := key(c)
	if err != nil {
		return "", fmt.Errorf("error retrieving key: %v", err)
	}

	blockCipher, err := aes.NewCipher(secret)
	if err != nil {
		return "", fmt.Errorf("error creating cipher block: %v", err)
	}

	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", fmt.Errorf("error creating new gcm: %v", err)
	}

	iv := make([]byte, aead.NonceSize())
	if _, err = io.ReadFull(rand.Reader, iv); err != nil {
		return "", fmt.Errorf("error creating nonce: %v", err)
	}

	// Seal appends to its first argument, so the nonce ends up as the
	// prefix of the output.
	sealed := aead.Seal(iv, iv, []byte(plaintext), nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}

// aesDecrypt reverses aesEncrypt. It base64-decodes ciphertext, splits off
// the leading GCM nonce, and opens (decrypts and authenticates) the
// remainder with the key obtained from key(c).
//
// Input too short to even contain a nonce is reported as an error instead of
// panicking on the slice bounds. Examples are an empty column value and
// truncated data.
func aesDecrypt(c *Config, ciphertext string) (string, error) {
	secret, err := key(c)
	if err != nil {
		return "", fmt.Errorf("error retrieving key: %v", err)
	}

	block, err := aes.NewCipher(secret)
	if err != nil {
		return "", fmt.Errorf("error creating cipher block: %v", err)
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", fmt.Errorf("error creating new gcm: %v", err)
	}

	data, err := base64.StdEncoding.DecodeString(ciphertext)
	if err != nil {
		return "", fmt.Errorf("error base64 decoding ciphertext: %v", err)
	}

	nonceSize := gcm.NonceSize()
	if len(data) < nonceSize {
		return "", fmt.Errorf("error aes decoding ciphertext: got %d bytes, need at least %d for the nonce", len(data), nonceSize)
	}
	nonce, sealed := data[:nonceSize], data[nonceSize:]

	// Open also verifies the authentication tag, so tampered or truncated
	// data is rejected here.
	plaintext, err := gcm.Open(nil, nonce, sealed, nil)
	if err != nil {
		return "", fmt.Errorf("error aes decoding ciphertext: %v", err)
	}

	return string(plaintext), nil
}

// AddUser inserts u into the users table and returns the new row's ID.
//
// pass is stored as a bcrypt hash. u.FirstName, u.LastName and u.Email are
// AES-GCM encrypted with the key from c before insertion. u.ID is not used.
func (db *DB) AddUser(c *Config, u *User, pass string) (int64, error) {
	hash, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost)
	if err != nil {
		return 0, fmt.Errorf("error creating password hash: %v", err)
	}

	var encFirst, encLast, encEmail string
	if encFirst, err = aesEncrypt(c, u.FirstName); err != nil {
		return 0, fmt.Errorf("error encrypting first name: %v", err)
	}
	if encLast, err = aesEncrypt(c, u.LastName); err != nil {
		return 0, fmt.Errorf("error encrypting last name: %v", err)
	}
	if encEmail, err = aesEncrypt(c, u.Email); err != nil {
		return 0, fmt.Errorf("error encrypting email: %v", err)
	}

	query := `
    INSERT INTO users (username, password, first_name, last_name, email, profile_pic_link, role)
    VALUES (?, ?, ?, ?, ?, ?, ?)
    `
	res, err := db.Exec(query, u.UserName, string(hash), encFirst, encLast, encEmail, u.ProfilePicLink, u.Role)
	if err != nil {
		return 0, fmt.Errorf("error inserting new user %v into DB: %v", u.UserName, err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("error inserting user into DB: %v", err)
	}
	return newID, nil
}

// GetID returns the id of the user whose username is userName, or 0 if no
// such user exists.
//
// 0 is also returned when the lookup itself fails (e.g. a connection error,
// not just sql.ErrNoRows). Such failures are logged so they are not
// indistinguishable from "user not found".
func (db *DB) GetID(userName string) int64 {
	var id int64

	query := `
    SELECT id
    FROM users
    WHERE username = ?
    `
	if err := db.QueryRow(query, userName).Scan(&id); err != nil {
		if !errors.Is(err, sql.ErrNoRows) {
			log.Printf("error getting ID of user %v: %v", userName, err)
		}
		return 0
	}

	return id
}

// CheckPassword verifies pass against the bcrypt hash stored for the user
// with the given id. It returns nil on a match.
//
// It returns an error if the row cannot be read, or if bcrypt rejects the
// password (a mismatch or an unusable stored hash).
func (db *DB) CheckPassword(id int64, pass string) error {
	const query = `
    SELECT password
    FROM users
    WHERE id = ?
    `
	var storedHash string
	if err := db.QueryRow(query, id).Scan(&storedHash); err != nil {
		return fmt.Errorf("error reading password from DB: %v", err)
	}

	err := bcrypt.CompareHashAndPassword([]byte(storedHash), []byte(pass))
	if err != nil {
		return fmt.Errorf("incorrect password: %v", err)
	}
	return nil
}

// ChangePassword sets the password of user id to newPass inside tx. It first
// checks that oldPass matches the stored bcrypt hash, and only then writes a
// bcrypt hash of newPass.
//
// On any failure tx is rolled back before the error is returned; if that
// rollback itself fails, the process exits via log.Fatalf. On success tx is
// left open for the caller to commit.
func (tx *Tx) ChangePassword(id int64, oldPass, newPass string) error {
	// abort rolls tx back and hands back wrapped; cause is the raw error that
	// is logged if the rollback fails.
	abort := func(cause, wrapped error) error {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			log.Fatalf("transaction error: %v, rollback error: %v", cause, rollbackErr)
		}
		return wrapped
	}

	getQuery := `
    SELECT password
    FROM users
    WHERE id = ?
    `
	var storedHash string
	if err := tx.QueryRow(getQuery, id).Scan(&storedHash); err != nil {
		return abort(err, fmt.Errorf("error reading password from DB: %v", err))
	}

	if err := bcrypt.CompareHashAndPassword([]byte(storedHash), []byte(oldPass)); err != nil {
		return abort(err, fmt.Errorf("incorrect password: %v", err))
	}

	freshHash, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost)
	if err != nil {
		return abort(err, fmt.Errorf("error creating password hash: %v", err))
	}

	setQuery := `
    UPDATE users
    SET password = ?
    WHERE id = ?
    `
	if _, err := tx.Exec(setQuery, string(freshHash), id); err != nil {
		return abort(err, fmt.Errorf("error updating password in DB: %v", err))
	}

	return nil
}

// GetUser loads the user with the given id from the users table. It decrypts
// the first name, last name and email with the key from c.
//
// TODO: No need for ID field in general
func (db *DB) GetUser(c *Config, id int64) (*User, error) {
	const query = `
    SELECT id, username, first_name, last_name, email, profile_pic_link, role
    FROM users
    WHERE id = ?
    `
	var encFirst, encLast, encEmail string
	u := &User{}

	if err := db.QueryRow(query, id).Scan(&u.ID, &u.UserName, &encFirst, &encLast, &encEmail, &u.ProfilePicLink, &u.Role); err != nil {
		return nil, fmt.Errorf("error reading user information: %v", err)
	}

	var err error
	if u.FirstName, err = aesDecrypt(c, encFirst); err != nil {
		return nil, fmt.Errorf("error decrypting first name: %v", err)
	}
	if u.LastName, err = aesDecrypt(c, encLast); err != nil {
		return nil, fmt.Errorf("error decrypting last name: %v", err)
	}
	if u.Email, err = aesDecrypt(c, encEmail); err != nil {
		return nil, fmt.Errorf("error decrypting email: %v", err)
	}

	return u, nil
}

// UpdateOwnUserAttributes updates the profile of user id in one transaction.
// userName and profilePicLink are stored as given; firstName, lastName and
// email are stored AES-GCM encrypted with the key from c. If newPass is
// non-empty the password is changed as well via tx.ChangePassword, which
// requires oldPass to match the current password.
//
// The transaction is retried up to TxMaxRetries times. On final failure the
// returned error wraps the error of the last attempt.
func (db *DB) UpdateOwnUserAttributes(c *Config, id int64, userName, firstName, lastName, email, profilePicLink, oldPass, newPass string) error {
	// Encryption does not touch the database, so it is done once up front
	// instead of inside (and on every retry of) the transaction.
	aesFirstName, err := aesEncrypt(c, firstName)
	if err != nil {
		return fmt.Errorf("error encrypting first name: %v", err)
	}
	aesLastName, err := aesEncrypt(c, lastName)
	if err != nil {
		return fmt.Errorf("error encrypting last name: %v", err)
	}
	aesEmail, err := aesEncrypt(c, email)
	if err != nil {
		return fmt.Errorf("error encrypting email: %v", err)
	}

	tx := new(Tx)
	passwordEmpty := len(newPass) == 0

	// rollback aborts the current attempt. tx.ChangePassword already rolls tx
	// back before returning an error, so this second Rollback reports
	// sql.ErrTxDone (assuming Tx.Rollback is the embedded *sql.Tx method).
	// That is expected and must not terminate the process; otherwise a wrong
	// old password would bring the server down.
	rollback := func(cause error) {
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			log.Fatalf("transaction error: %v, rollback error: %v", cause, rollbackErr)
		}
	}

	var lastErr error
	for i := 0; i < TxMaxRetries; i++ {
		lastErr = func() error {
			sqlTx, err := db.Begin()
			if err != nil {
				return fmt.Errorf("error starting transaction: %v", err)
			}
			tx.Tx = sqlTx

			if !passwordEmpty {
				if err = tx.ChangePassword(id, oldPass, newPass); err != nil {
					rollback(err)
					return fmt.Errorf("error changing password: %v", err)
				}
			}

			if err = tx.UpdateAttributes(
				&Attribute{Table: "users", ID: id, AttName: "username", Value: userName},
				&Attribute{Table: "users", ID: id, AttName: "first_name", Value: aesFirstName},
				&Attribute{Table: "users", ID: id, AttName: "last_name", Value: aesLastName},
				&Attribute{Table: "users", ID: id, AttName: "email", Value: aesEmail},
				&Attribute{Table: "users", ID: id, AttName: "profile_pic_link", Value: profilePicLink},
			); err != nil {
				rollback(err)
				return fmt.Errorf("error updating attributes in DB: %v", err)
			}

			if err = tx.Commit(); err != nil {
				return fmt.Errorf("error committing transaction: %v", err)
			}

			return nil
		}()
		if lastErr == nil {
			return nil
		}

		log.Println(lastErr)
		wait(i)
	}

	return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting: %w", TxMaxRetries, lastErr)
}

// AddFirstUser inserts u (with password pass, stored as a bcrypt hash) only
// if the users table is still empty. It returns the new user's ID, or -1
// with a nil error if at least one user already exists.
//
// The count and the insert run in one transaction at sql.LevelSerializable,
// presumably so that two concurrent callers cannot both insert a "first"
// user. The actual guarantee depends on the database driver. The
// transaction is retried up to TxMaxRetries times, and on final failure the
// returned error wraps the last attempt's error.
func (db *DB) AddFirstUser(c *Config, u *User, pass string) (int64, error) {
	txOptions := &sql.TxOptions{Isolation: sql.LevelSerializable}
	selectQuery := "SELECT COUNT(*) FROM users"
	insertQuery := `
    INSERT INTO users (username, password, first_name, last_name, email, profile_pic_link, role)
    VALUES (?, ?, ?, ?, ?, ?, ?)
    `

	var lastErr error
	for i := 0; i < TxMaxRetries; i++ {
		id, err := func() (int64, error) {
			tx, err := db.BeginTx(context.Background(), txOptions)
			if err != nil {
				return 0, fmt.Errorf("error starting transaction: %v", err)
			}

			// abort rolls back this attempt and returns wrapped as its error.
			// A failed rollback terminates the process, as elsewhere in this
			// file.
			abort := func(cause, wrapped error) (int64, error) {
				if rollbackErr := tx.Rollback(); rollbackErr != nil {
					log.Fatalf("transaction error: %v, rollback error: %v", cause, rollbackErr)
				}
				return 0, wrapped
			}

			var numUsers int64
			if err := tx.QueryRow(selectQuery).Scan(&numUsers); err != nil {
				return abort(err, fmt.Errorf("error counting users before adding %v: %v", u.UserName, err))
			}
			if numUsers != 0 {
				// Not the first user: finish the read-only transaction and
				// report -1 without inserting anything.
				if err := tx.Commit(); err != nil {
					return 0, fmt.Errorf("error committing transaction: %v", err)
				}
				return -1, nil
			}

			hashedPass, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost)
			if err != nil {
				return abort(err, fmt.Errorf("error creating password hash: %v", err))
			}

			aesFirstName, err := aesEncrypt(c, u.FirstName)
			if err != nil {
				return abort(err, fmt.Errorf("error encrypting first name: %v", err))
			}

			aesLastName, err := aesEncrypt(c, u.LastName)
			if err != nil {
				return abort(err, fmt.Errorf("error encrypting last name: %v", err))
			}

			aesEmail, err := aesEncrypt(c, u.Email)
			if err != nil {
				return abort(err, fmt.Errorf("error encrypting email: %v", err))
			}

			result, err := tx.Exec(insertQuery, u.UserName, string(hashedPass), aesFirstName, aesLastName, aesEmail, u.ProfilePicLink, u.Role)
			if err != nil {
				return abort(err, fmt.Errorf("error inserting new user %v into DB: %v", u.UserName, err))
			}

			id, err := result.LastInsertId()
			if err != nil {
				return abort(err, fmt.Errorf("error inserting user into DB: %v", err))
			}

			if err := tx.Commit(); err != nil {
				return 0, fmt.Errorf("error committing transaction: %v", err)
			}
			return id, nil
		}()
		if err == nil {
			return id, nil
		}
		lastErr = err

		log.Println(err)
		wait(i)
	}
	return 0, fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting: %w", TxMaxRetries, lastErr)
}

// GetAllUsers returns every row of the users table as a slice, with first
// name, last name and email decrypted using the key from c. The query has no
// ORDER BY, so the order is whatever the database returns. An empty table
// yields an empty (non-nil) slice.
func (db *DB) GetAllUsers(c *Config) ([]*User, error) {
	query := "SELECT id, username, first_name, last_name, email, profile_pic_link, role FROM users"

	rows, err := db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("error getting all users from DB: %v", err)
	}
	// Release the underlying connection on every path, including the early
	// error returns inside the loop.
	defer rows.Close()

	users := make([]*User, 0)
	for rows.Next() {
		var aesFirstName, aesLastName, aesEmail string
		user := new(User)
		if err = rows.Scan(&user.ID, &user.UserName, &aesFirstName, &aesLastName, &aesEmail, &user.ProfilePicLink, &user.Role); err != nil {
			return nil, fmt.Errorf("error getting user info: %v", err)
		}

		user.FirstName, err = aesDecrypt(c, aesFirstName)
		if err != nil {
			return nil, fmt.Errorf("error decrypting first name: %v", err)
		}

		user.LastName, err = aesDecrypt(c, aesLastName)
		if err != nil {
			return nil, fmt.Errorf("error decrypting last name: %v", err)
		}

		user.Email, err = aesDecrypt(c, aesEmail)
		if err != nil {
			return nil, fmt.Errorf("error decrypting email: %v", err)
		}

		users = append(users, user)
	}
	// rows.Next returning false can signal an iteration error rather than
	// the end of the result set; don't return a silently truncated list.
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("error getting all users from DB: %v", err)
	}

	return users, nil
}

// GetAllUsersMap returns every row of the users table keyed by user ID, with
// first name, last name and email decrypted using the key from c. An empty
// table yields an empty (non-nil) map.
func (db *DB) GetAllUsersMap(c *Config) (map[int64]*User, error) {
	query := "SELECT id, username, first_name, last_name, email, profile_pic_link, role FROM users"

	rows, err := db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("error getting all users from DB: %v", err)
	}
	// Release the underlying connection on every path, including the early
	// error returns inside the loop.
	defer rows.Close()

	users := make(map[int64]*User)
	for rows.Next() {
		var aesFirstName, aesLastName, aesEmail string
		user := new(User)
		if err = rows.Scan(&user.ID, &user.UserName, &aesFirstName, &aesLastName, &aesEmail, &user.ProfilePicLink, &user.Role); err != nil {
			return nil, fmt.Errorf("error getting user info: %v", err)
		}

		user.FirstName, err = aesDecrypt(c, aesFirstName)
		if err != nil {
			return nil, fmt.Errorf("error decrypting first name: %v", err)
		}

		user.LastName, err = aesDecrypt(c, aesLastName)
		if err != nil {
			return nil, fmt.Errorf("error decrypting last name: %v", err)
		}

		user.Email, err = aesDecrypt(c, aesEmail)
		if err != nil {
			return nil, fmt.Errorf("error decrypting email: %v", err)
		}

		users[user.ID] = user
	}
	// rows.Next returning false can signal an iteration error rather than
	// the end of the result set; don't return a silently incomplete map.
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("error getting all users from DB: %v", err)
	}

	return users, nil
}

// SetPassword unconditionally replaces the password of user id with a bcrypt
// hash of newPass, inside tx. Unlike ChangePassword, the old password is not
// checked.
//
// On failure tx is rolled back before the error is returned; if that
// rollback itself fails, the process exits via log.Fatalf. On success tx is
// left open for the caller to commit.
func (tx *Tx) SetPassword(id int64, newPass string) error {
	// abort rolls tx back and hands back wrapped; cause is the raw error that
	// is logged if the rollback fails.
	abort := func(cause, wrapped error) error {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			log.Fatalf("transaction error: %v, rollback error: %v", cause, rollbackErr)
		}
		return wrapped
	}

	digest, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost)
	if err != nil {
		return abort(err, fmt.Errorf("error creating password hash: %v", err))
	}

	const setQuery = "UPDATE users SET password = ? WHERE id = ?"
	if _, err := tx.Exec(setQuery, string(digest), id); err != nil {
		return abort(err, fmt.Errorf("error updating password in DB: %v", err))
	}

	return nil
}

// UpdateUserAttributes overwrites the profile of user id, including role, in
// one transaction. userName, profilePicLink and role are stored as given;
// firstName, lastName and email are stored AES-GCM encrypted with the key
// from c. If newPass is non-empty the password is replaced via
// tx.SetPassword, without checking the old one.
//
// The transaction is retried up to TxMaxRetries times. On final failure the
// returned error wraps the error of the last attempt.
func (db *DB) UpdateUserAttributes(c *Config, id int64, userName, firstName, lastName, email, profilePicLink, newPass string, role int) error {
	// Encryption does not touch the database, so it is done once up front
	// instead of inside (and on every retry of) the transaction.
	aesFirstName, err := aesEncrypt(c, firstName)
	if err != nil {
		return fmt.Errorf("error encrypting first name: %v", err)
	}
	aesLastName, err := aesEncrypt(c, lastName)
	if err != nil {
		return fmt.Errorf("error encrypting last name: %v", err)
	}
	aesEmail, err := aesEncrypt(c, email)
	if err != nil {
		return fmt.Errorf("error encrypting email: %v", err)
	}

	tx := new(Tx)
	passwordEmpty := len(newPass) == 0

	// rollback aborts the current attempt. tx.SetPassword already rolls tx
	// back before returning an error, so this second Rollback reports
	// sql.ErrTxDone (assuming Tx.Rollback is the embedded *sql.Tx method).
	// That is expected and must not terminate the process.
	rollback := func(cause error) {
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			log.Fatalf("transaction error: %v, rollback error: %v", cause, rollbackErr)
		}
	}

	var lastErr error
	for i := 0; i < TxMaxRetries; i++ {
		lastErr = func() error {
			sqlTx, err := db.Begin()
			if err != nil {
				return fmt.Errorf("error starting transaction: %v", err)
			}
			tx.Tx = sqlTx

			if !passwordEmpty {
				if err = tx.SetPassword(id, newPass); err != nil {
					rollback(err)
					return fmt.Errorf("error changing password: %v", err)
				}
			}

			if err = tx.UpdateAttributes(
				&Attribute{Table: "users", ID: id, AttName: "username", Value: userName},
				&Attribute{Table: "users", ID: id, AttName: "first_name", Value: aesFirstName},
				&Attribute{Table: "users", ID: id, AttName: "last_name", Value: aesLastName},
				&Attribute{Table: "users", ID: id, AttName: "email", Value: aesEmail},
				&Attribute{Table: "users", ID: id, AttName: "profile_pic_link", Value: profilePicLink},
				&Attribute{Table: "users", ID: id, AttName: "role", Value: role},
			); err != nil {
				rollback(err)
				return fmt.Errorf("error updating attributes in DB: %v", err)
			}

			if err = tx.Commit(); err != nil {
				return fmt.Errorf("error committing transaction: %v", err)
			}

			return nil
		}()
		if lastErr == nil {
			return nil
		}

		log.Println(lastErr)
		wait(i)
	}

	return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting: %w", TxMaxRetries, lastErr)
}

// DeleteUser removes the row with the given id from the users table. Only
// the Exec error is checked, so deleting an id that matches no row is not
// reported as an error.
func (db *DB) DeleteUser(id int64) error {
	if _, err := db.Exec("DELETE FROM users WHERE id = ?", id); err != nil {
		return fmt.Errorf("error deleting user %v from DB: %v", id, err)
	}
	return nil
}