168 lines
3.8 KiB
Go
168 lines
3.8 KiB
Go
package model
|
|
|
|
import (
|
|
"bufio"
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"math"
|
|
"math/rand/v2"
|
|
"os"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/go-sql-driver/mysql"
|
|
"golang.org/x/term"
|
|
)
|
|
|
|
var (
	// TxMaxRetries is the number of attempts DB.UpdateAttributes makes to
	// run and commit its transaction before giving up.
	TxMaxRetries = 5
)
|
|
|
|
// DB wraps a *sql.DB; every sql.DB method is promoted onto it.
type DB struct{ *sql.DB }

// Tx wraps a *sql.Tx; every sql.Tx method is promoted onto it.
type Tx struct{ *sql.Tx }

// Attribute describes a single column update, executed as
// UPDATE <Table> SET <AttName> = <Value> WHERE id = <ID>.
type Attribute struct {
	Value   any    // new column value, bound as a query parameter
	Table   string // table name, interpolated into the SQL text
	AttName string // column name, interpolated into the SQL text
	ID      int64  // matched against the table's id column, bound as a parameter
}
|
|
|
|
// getUsername returns the database user name. It uses the DB_USER environment
// variable when it is set; otherwise it prompts on stdout and reads one line
// from stdin. Surrounding whitespace is trimmed in both cases.
func getUsername() (string, error) {
	if user := os.Getenv("DB_USER"); user != "" {
		return strings.TrimSpace(user), nil
	}

	fmt.Printf("DB Benutzer: ")
	// A Scanner also accepts a final line without a trailing newline (e.g.
	// piped input), which bufio.Reader.ReadString('\n') reports as io.EOF.
	sc := bufio.NewScanner(os.Stdin)
	if !sc.Scan() {
		if err := sc.Err(); err != nil {
			return "", fmt.Errorf("error reading username: %w", err)
		}
		return "", fmt.Errorf("error reading username: no input")
	}
	return strings.TrimSpace(sc.Text()), nil
}
|
|
|
|
func getPassword() (string, error) {
|
|
pass := os.Getenv("DB_PASS")
|
|
if pass == "" {
|
|
fmt.Printf("DB Passwort: ")
|
|
bytePass, err := term.ReadPassword(int(syscall.Stdin))
|
|
if err != nil {
|
|
return "", fmt.Errorf("error reading password: %v", err)
|
|
}
|
|
fmt.Println()
|
|
pass = strings.TrimSpace(string(bytePass))
|
|
}
|
|
return pass, nil
|
|
}
|
|
|
|
func getCredentials() (string, string, error) {
|
|
user, err := getUsername()
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("error getting username: %v", err)
|
|
}
|
|
|
|
pass, err := getPassword()
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("error getting password: %v", err)
|
|
}
|
|
|
|
return user, pass, nil
|
|
}
|
|
|
|
func wait(iteration int) {
|
|
waitTime := time.Duration(math.Pow(2, float64(iteration))) * 100 * time.Millisecond
|
|
jitter := time.Duration(rand.IntN(int(waitTime)/2)) * time.Millisecond
|
|
time.Sleep(waitTime + jitter)
|
|
}
|
|
|
|
func OpenDB(dbName string) (*DB, error) {
|
|
var err error
|
|
db := DB{DB: new(sql.DB)}
|
|
|
|
cfg := mysql.NewConfig()
|
|
cfg.DBName = dbName
|
|
cfg.User, cfg.Passwd, err = getCredentials()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading user credentials for DB: %v", err)
|
|
}
|
|
|
|
db.DB, err = sql.Open("mysql", cfg.FormatDSN())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error opening DB: %v", err)
|
|
}
|
|
if err = db.Ping(); err != nil {
|
|
return nil, fmt.Errorf("error pinging DB: %v", err)
|
|
}
|
|
|
|
return &db, nil
|
|
}
|
|
|
|
func (db *DB) UpdateAttributes(a ...*Attribute) error {
|
|
for i := 0; i < TxMaxRetries; i++ {
|
|
err := func() error {
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("error starting transaction: %v", err)
|
|
}
|
|
|
|
for _, attribute := range a {
|
|
query := fmt.Sprintf(`
|
|
UPDATE %s
|
|
SET %s = ?
|
|
WHERE id = ?
|
|
`, attribute.Table, attribute.AttName)
|
|
if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
log.Fatalf("transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
}
|
|
return fmt.Errorf("error updating %v in DB: %v", attribute.AttName, err)
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return fmt.Errorf("error committing transaction: %v", err)
|
|
}
|
|
return nil
|
|
}()
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
log.Println(err)
|
|
wait(i)
|
|
}
|
|
return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries)
|
|
}
|
|
|
|
func (db *DB) CountEntries(table string) (int64, error) {
|
|
var count int64
|
|
|
|
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
|
|
row := db.QueryRow(query)
|
|
if err := row.Scan(&count); err != nil {
|
|
return 0, fmt.Errorf("error counting rows in user DB: %v", err)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func (tx *Tx) UpdateAttributes(a ...*Attribute) error {
|
|
for _, attribute := range a {
|
|
query := fmt.Sprintf(`
|
|
UPDATE %s
|
|
SET %s = ?
|
|
WHERE id = ?
|
|
`, attribute.Table, attribute.AttName)
|
|
if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
log.Fatalf("transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
}
|
|
return fmt.Errorf("error updating %v in DB: %v", attribute.AttName, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|