package backend

import (
	"bufio"
	"database/sql"
	"fmt"
	"log"
	"math"
	"math/rand/v2"
	"os"
	"strings"
	"syscall"
	"time"

	"github.com/go-sql-driver/mysql"
	"golang.org/x/term"
)

// TxMaxRetries is the number of times DB.UpdateAttributes attempts its
// transaction before giving up. Despite the name it counts total attempts
// (including the first one), not retries after an initial failure.
var TxMaxRetries = 5

type (
	// DB embeds *sql.DB so this package can attach helper methods while all
	// *sql.DB methods remain available.
	DB struct{ *sql.DB }
	// Tx embeds *sql.Tx so this package can attach helper methods while all
	// *sql.Tx methods remain available.
	Tx struct{ *sql.Tx }

	// Attribute describes a single column update: set column AttName of the
	// row in Table whose id equals ID to Value.
	Attribute struct {
		Value   any    // new column value, passed to the driver as a bound parameter
		Table   string // table name; interpolated directly into the SQL text
		AttName string // column name; interpolated directly into the SQL text
		ID      int64  // matched against the table's id column (WHERE id = ?)
	}
)

// getUsername returns the database user name. It uses the DB_USER environment
// variable when set; otherwise it prompts on stdout and reads one line from
// stdin. Surrounding whitespace is trimmed in both cases.
func getUsername() (string, error) {
	user := os.Getenv("DB_USER")
	if user == "" {
		fmt.Print("DB Benutzer: ")
		// A Scanner also accepts a final line that is not newline-terminated
		// (e.g. piped input without a trailing newline), which
		// ReadString('\n') would report as an error.
		sc := bufio.NewScanner(os.Stdin)
		if !sc.Scan() {
			if err := sc.Err(); err != nil {
				return "", fmt.Errorf("error reading username: %w", err)
			}
			return "", fmt.Errorf("error reading username: no input")
		}
		user = sc.Text()
	}
	return strings.TrimSpace(user), nil
}

// getPassword returns the database password. It uses the DB_PASS environment
// variable verbatim when set; otherwise it prompts on stdout and reads the
// password from the terminal with echo disabled, trimming surrounding
// whitespace from what was typed.
func getPassword() (string, error) {
	if pass := os.Getenv("DB_PASS"); pass != "" {
		return pass, nil
	}

	fmt.Print("DB Passwort: ")
	bytePass, err := term.ReadPassword(int(syscall.Stdin))
	// Echo is off while reading, so the user's Enter never reached the
	// screen; terminate the prompt line ourselves, even if reading failed.
	fmt.Println()
	if err != nil {
		return "", fmt.Errorf("error reading password: %w", err)
	}
	return strings.TrimSpace(string(bytePass)), nil
}

// getCredentials returns the database user name and password, obtained via
// getUsername and getPassword respectively. Underlying errors are wrapped
// with %w so callers can inspect them with errors.Is / errors.As.
func getCredentials() (string, string, error) {
	user, err := getUsername()
	if err != nil {
		return "", "", fmt.Errorf("error getting username: %w", err)
	}

	pass, err := getPassword()
	if err != nil {
		return "", "", fmt.Errorf("error getting password: %w", err)
	}

	return user, pass, nil
}

func wait(iteration int) {
	waitTime := time.Duration(math.Pow(2, float64(iteration))) * 100 * time.Millisecond
	jitter := time.Duration(rand.IntN(int(waitTime)/2)) * time.Millisecond
	time.Sleep(waitTime + jitter)
}

// OpenDB connects to the MySQL database dbName using credentials obtained from
// getCredentials. The connection is verified with a ping before returning; if
// the ping fails, the handle returned by sql.Open is closed so nothing leaks.
func OpenDB(dbName string) (*DB, error) {
	user, pass, err := getCredentials()
	if err != nil {
		return nil, fmt.Errorf("error reading user credentials for DB: %w", err)
	}

	cfg := mysql.NewConfig()
	cfg.DBName = dbName
	cfg.User = user
	cfg.Passwd = pass

	sqlDB, err := sql.Open("mysql", cfg.FormatDSN())
	if err != nil {
		return nil, fmt.Errorf("error opening DB: %w", err)
	}
	if err = sqlDB.Ping(); err != nil {
		// The ping error is the one worth reporting; a Close error here would
		// only obscure it.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("error pinging DB: %w", err)
	}

	return &DB{DB: sqlDB}, nil
}

// UpdateAttributes applies all updates in a within a single transaction, so
// either every attribute is written or none is. A failed attempt is rolled
// back, logged, and retried after an exponential backoff, for up to
// TxMaxRetries attempts in total; the last attempt's error is wrapped in the
// returned error. Calling it with no attributes is a no-op, and a nil
// attribute is rejected before any transaction is started.
//
// Table and AttName are interpolated into the SQL text (identifiers cannot be
// bound as parameters), so they must come from trusted code, never from user
// input. Value and ID are passed as bound parameters.
func (db *DB) UpdateAttributes(a ...*Attribute) error {
	if len(a) == 0 {
		return nil
	}
	for i, attribute := range a {
		if attribute == nil {
			return fmt.Errorf("error updating attributes: attribute %d is nil", i)
		}
	}

	var lastErr error
	for i := 0; i < TxMaxRetries; i++ {
		if i > 0 {
			// Back off only between attempts; sleeping after the final
			// failure would just delay the error.
			wait(i - 1)
		}
		if lastErr = db.updateAttributesOnce(a); lastErr == nil {
			return nil
		}
		log.Println(lastErr)
	}
	if lastErr == nil {
		// Only reachable when TxMaxRetries <= 0, i.e. nothing was attempted.
		return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries)
	}
	return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting: %w", TxMaxRetries, lastErr)
}

// updateAttributesOnce runs a single transactional attempt for
// UpdateAttributes. The transaction is rolled back on any error or panic that
// occurs before Commit; a rollback failure is appended to the returned error.
func (db *DB) updateAttributesOnce(attrs []*Attribute) (err error) {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("error starting transaction: %w", err)
	}
	finished := false
	defer func() {
		if finished {
			return
		}
		if rbErr := tx.Rollback(); rbErr != nil && err != nil {
			err = fmt.Errorf("%w; rollback error: %w", err, rbErr)
		}
	}()

	for _, attribute := range attrs {
		query := fmt.Sprintf("UPDATE %s SET %s = ? WHERE id = ?", attribute.Table, attribute.AttName)
		if _, execErr := tx.Exec(query, attribute.Value, attribute.ID); execErr != nil {
			return fmt.Errorf("error updating %v in DB: %w", attribute.AttName, execErr)
		}
	}

	// After Commit the transaction is finished either way (a later Rollback
	// would only report sql.ErrTxDone), so the deferred rollback must not run.
	finished = true
	if commitErr := tx.Commit(); commitErr != nil {
		return fmt.Errorf("error committing transaction: %w", commitErr)
	}
	return nil
}

// CountEntries returns the number of rows in table.
//
// table is interpolated directly into the SQL text (identifiers cannot be
// bound as query parameters), so it must come from trusted code, never from
// user input.
func (db *DB) CountEntries(table string) (int64, error) {
	var count int64

	query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
	if err := db.QueryRow(query).Scan(&count); err != nil {
		return 0, fmt.Errorf("error counting rows in table %q: %w", table, err)
	}

	return count, nil
}

// UpdateAttributes executes one UPDATE per attribute inside tx, in order.
// If any UPDATE fails, tx is rolled back and the error is returned (combined
// with the rollback error, if the rollback itself fails). On success tx is
// left open for the caller to commit. A nil attribute is rejected before
// anything is executed, leaving tx untouched.
//
// Table and AttName are interpolated into the SQL text (identifiers cannot be
// bound as parameters), so they must come from trusted code, never from user
// input. Value and ID are passed as bound parameters.
func (tx *Tx) UpdateAttributes(a ...*Attribute) error {
	for i, attribute := range a {
		if attribute == nil {
			return fmt.Errorf("error updating attributes: attribute %d is nil", i)
		}
	}

	for _, attribute := range a {
		query := fmt.Sprintf("UPDATE %s SET %s = ? WHERE id = ?", attribute.Table, attribute.AttName)
		if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil {
			err = fmt.Errorf("error updating %v in DB: %w", attribute.AttName, err)
			// sql.ErrTxDone means the transaction was already finished (e.g.
			// committed by the caller), so there is nothing left to roll back.
			// database/sql returns this sentinel unwrapped from Rollback.
			if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
				err = fmt.Errorf("%w; rollback error: %w", err, rbErr)
			}
			return err
		}
	}

	return nil
}