// Package backend is a thin MySQL access layer: it opens a connection using
// credentials from the environment or an interactive prompt, and updates
// single columns of rows identified by their id, optionally with retries.
package backend

import (
	"bufio"
	"database/sql"
	"errors"
	"fmt"
	"io"
	"log"
	"math"
	"math/rand/v2"
	"os"
	"strings"
	"syscall"
	"time"

	"github.com/go-sql-driver/mysql"
	"golang.org/x/term"
)

// TxMaxRetries is the number of attempts DB.UpdateAttributes makes before
// giving up. With a value <= 0 it fails immediately without touching the DB.
var TxMaxRetries = 5

type (
	// DB wraps *sql.DB so this package can attach helper methods to it.
	DB struct{ *sql.DB }

	// Tx wraps *sql.Tx so this package can attach helper methods to it.
	Tx struct{ *sql.Tx }

	// Attribute describes a single-column update: set column AttName of the
	// row whose id equals ID in table Table to Value.
	//
	// Table and AttName are interpolated into the SQL text verbatim; only
	// Value and ID are bound as query parameters. They must therefore come
	// from trusted code, never from user input.
	Attribute struct {
		Value   any
		Table   string
		AttName string
		ID      int64
	}
)

// getUsername returns the DB user from $DB_USER or, if that is empty, prompts
// for it on stdin. Surrounding whitespace is trimmed in both cases.
func getUsername() (string, error) {
	user := os.Getenv("DB_USER")
	if user == "" {
		fmt.Printf("DB Benutzer: ") // "DB user: "
		line, err := bufio.NewReader(os.Stdin).ReadString('\n')
		// ReadString reports io.EOF when the input ends without a newline;
		// that still counts as a valid answer as long as something was read.
		if err != nil && (!errors.Is(err, io.EOF) || line == "") {
			return "", fmt.Errorf("error reading username: %w", err)
		}
		user = line
	}
	return strings.TrimSpace(user), nil
}

// getPassword returns the DB password from $DB_PASS or, if that is empty,
// prompts for it on the terminal without echoing the input. Only a prompted
// password is whitespace-trimmed; the environment value is used as-is.
func getPassword() (string, error) {
	pass := os.Getenv("DB_PASS")
	if pass == "" {
		fmt.Printf("DB Passwort: ") // "DB password: "
		bytePass, err := term.ReadPassword(int(syscall.Stdin))
		if err != nil {
			return "", fmt.Errorf("error reading password: %w", err)
		}
		// ReadPassword does not echo the user's Enter; end the prompt line.
		fmt.Println()
		pass = strings.TrimSpace(string(bytePass))
	}
	return pass, nil
}

// getCredentials returns the DB user name and password, in that order.
func getCredentials() (string, string, error) {
	user, err := getUsername()
	if err != nil {
		return "", "", fmt.Errorf("error getting username: %w", err)
	}
	pass, err := getPassword()
	if err != nil {
		return "", "", fmt.Errorf("error getting password: %w", err)
	}
	return user, pass, nil
}

// wait sleeps before the next retry using exponential backoff: 100ms·2^iteration
// plus a random jitter in [0, half of that), so concurrent clients do not retry
// in lockstep.
func wait(iteration int) {
	waitTime := time.Duration(math.Pow(2, float64(iteration))) * 100 * time.Millisecond
	// waitTime is a Duration, i.e. nanoseconds, and the jitter must stay in
	// that unit. Scaling it by time.Millisecond again turned a sub-second
	// jitter into hours.
	var jitter time.Duration
	if half := int64(waitTime / 2); half > 0 { // rand.Int64N panics for n <= 0
		jitter = time.Duration(rand.Int64N(half))
	}
	time.Sleep(waitTime + jitter)
}

// OpenDB connects to the MySQL database dbName with credentials from
// getCredentials ($DB_USER/$DB_PASS or an interactive prompt) and verifies the
// connection with a ping. The caller owns the returned DB and should Close it.
func OpenDB(dbName string) (*DB, error) {
	user, pass, err := getCredentials()
	if err != nil {
		return nil, fmt.Errorf("error reading user credentials for DB: %w", err)
	}

	cfg := mysql.NewConfig()
	cfg.DBName = dbName
	cfg.User, cfg.Passwd = user, pass

	sqlDB, err := sql.Open("mysql", cfg.FormatDSN())
	if err != nil {
		return nil, fmt.Errorf("error opening DB: %w", err)
	}
	if err := sqlDB.Ping(); err != nil {
		// The caller never gets a handle to this pool, so release it here.
		// Best effort: the ping error is the one worth reporting.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("error pinging DB: %w", err)
	}
	return &DB{DB: sqlDB}, nil
}

// UpdateAttributes applies all updates in a inside a single transaction, so
// they either all take effect or none do. The whole transaction is attempted
// up to TxMaxRetries times, with a backoff pause (see wait) between attempts;
// each failed attempt is logged. It returns nil after the first successful
// commit; otherwise the returned error wraps the last attempt's failure.
//
// NOTE(review): every error is treated as retryable, including permanent ones
// such as an unknown column. Consider retrying only transient MySQL errors
// (e.g. deadlock or lock-wait timeout).
func (db *DB) UpdateAttributes(a ...*Attribute) error {
	var lastErr error
	for i := 0; i < TxMaxRetries; i++ {
		if i > 0 {
			// Same pause sequence as before (wait(0), wait(1), ...), but no
			// pointless sleep after the final failed attempt.
			wait(i - 1)
		}
		if lastErr = db.updateAttributesOnce(a); lastErr == nil {
			return nil
		}
		log.Println(lastErr)
	}
	if lastErr == nil { // TxMaxRetries <= 0: nothing was attempted
		return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries)
	}
	return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting: %w", TxMaxRetries, lastErr)
}

// updateAttributesOnce performs a single attempt of DB.UpdateAttributes:
// begin a transaction, apply every update, then commit.
func (db *DB) updateAttributesOnce(a []*Attribute) error {
	sqlTx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("error starting transaction: %w", err)
	}
	// Tx.UpdateAttributes rolls the transaction back itself when an update fails.
	if err := (&Tx{Tx: sqlTx}).UpdateAttributes(a...); err != nil {
		return err
	}
	if err := sqlTx.Commit(); err != nil {
		return fmt.Errorf("error committing transaction: %w", err)
	}
	return nil
}

// CountEntries returns the number of rows in table. The table name is
// interpolated into the query verbatim and must not come from untrusted input.
func (db *DB) CountEntries(table string) (int64, error) {
	var count int64
	query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
	if err := db.QueryRow(query).Scan(&count); err != nil {
		return 0, fmt.Errorf("error counting rows in table %s: %w", table, err)
	}
	return count, nil
}

// UpdateAttributes applies each update in a, in order, within the receiver's
// transaction. It does not commit. If an update fails, the transaction is
// rolled back and the error is returned, so the caller must not keep using tx.
// A failing rollback is reported together with the update error rather than
// terminating the process.
func (tx *Tx) UpdateAttributes(a ...*Attribute) error {
	for _, attribute := range a {
		query := fmt.Sprintf("UPDATE %s SET %s = ? WHERE id = ?", attribute.Table, attribute.AttName)
		if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil {
			updateErr := fmt.Errorf("error updating %v in DB: %w", attribute.AttName, err)
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
				return errors.Join(updateErr, fmt.Errorf("rollback error: %w", rollbackErr))
			}
			return updateErr
		}
	}
	return nil
}