2024-03-09 10:25:20 +01:00
|
|
|
package model
|
2024-02-18 16:37:13 +01:00
|
|
|
|
|
|
|
import (
|
2024-03-09 11:06:03 +01:00
|
|
|
"bufio"
|
2024-02-18 16:37:13 +01:00
|
|
|
"database/sql"
|
|
|
|
"fmt"
|
2024-03-10 15:03:46 +01:00
|
|
|
"log"
|
2024-03-15 15:18:02 +01:00
|
|
|
"math"
|
|
|
|
"math/rand/v2"
|
2024-03-09 11:06:03 +01:00
|
|
|
"os"
|
|
|
|
"strings"
|
|
|
|
"syscall"
|
2024-03-15 15:18:02 +01:00
|
|
|
"time"
|
2024-02-18 16:37:13 +01:00
|
|
|
|
|
|
|
"github.com/go-sql-driver/mysql"
|
2024-03-09 11:06:03 +01:00
|
|
|
"golang.org/x/term"
|
2024-02-18 16:37:13 +01:00
|
|
|
)
|
|
|
|
|
2024-03-15 15:18:02 +01:00
|
|
|
var (
	// TxMaxRetries is the maximum number of attempts DB.UpdateAttributes makes
	// to apply its transaction before giving up. Being a package-level
	// variable, it can be adjusted at runtime.
	TxMaxRetries = 3
)
|
|
|
|
|
2024-02-18 16:37:13 +01:00
|
|
|
type DB struct {
|
|
|
|
*sql.DB
|
|
|
|
}
|
|
|
|
|
2024-03-11 21:08:27 +01:00
|
|
|
type Tx struct {
|
|
|
|
*sql.Tx
|
|
|
|
}
|
|
|
|
|
|
|
|
// Attribute describes a single column update: set column AttName to Value in
// table Table for the row whose id column equals ID.
type Attribute struct {
	Value   any    // new value; sent as a bound query parameter
	Table   string // table name; interpolated into the SQL text verbatim
	AttName string // column name; interpolated into the SQL text verbatim
	ID      int64  // matched against the table's id column
}
|
|
|
|
|
2024-03-09 11:06:03 +01:00
|
|
|
func getUsername() (string, error) {
|
|
|
|
user := os.Getenv("DB_USER")
|
|
|
|
if user == "" {
|
|
|
|
var err error
|
|
|
|
fmt.Printf("DB Benutzer: ")
|
|
|
|
user, err = bufio.NewReader(os.Stdin).ReadString('\n')
|
|
|
|
if err != nil {
|
|
|
|
return "", fmt.Errorf("error reading username: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return strings.TrimSpace(user), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// getPassword returns the DB password from the DB_PASS environment variable,
// used verbatim. If DB_PASS is empty, it prompts on stdout and reads the
// password from the terminal via term.ReadPassword (which does not echo the
// input); surrounding whitespace of the typed password is trimmed.
func getPassword() (string, error) {
	if pass := os.Getenv("DB_PASS"); pass != "" {
		return pass, nil
	}

	fmt.Printf("DB Passwort: ")
	bytePass, err := term.ReadPassword(int(syscall.Stdin))
	// Input is not echoed, so the user's Enter never moves the cursor. End the
	// prompt line unconditionally, also when reading failed, so subsequent
	// output (including the error) starts on a fresh line.
	fmt.Println()
	if err != nil {
		return "", fmt.Errorf("error reading password: %w", err)
	}
	return strings.TrimSpace(string(bytePass)), nil
}
|
|
|
|
|
|
|
|
// getCredentials returns the DB user name and password, in that order, as
// obtained by getUsername and getPassword. The user name is read first, so an
// interactive session is prompted for it before the password.
//
// Errors from either step are wrapped with %w so callers can still inspect
// the underlying cause with errors.Is / errors.As.
func getCredentials() (string, string, error) {
	user, err := getUsername()
	if err != nil {
		return "", "", fmt.Errorf("error getting username: %w", err)
	}

	pass, err := getPassword()
	if err != nil {
		return "", "", fmt.Errorf("error getting password: %w", err)
	}

	return user, pass, nil
}
|
|
|
|
|
2024-02-18 16:37:13 +01:00
|
|
|
// OpenDB connects to the MySQL database dbName. It starts from the driver's
// defaults (mysql.NewConfig) and fills in credentials from getCredentials. The
// connection is verified with a ping before the handle is returned.
//
// On any error the returned *DB is nil and no connection pool is left open.
func OpenDB(dbName string) (*DB, error) {
	cfg := mysql.NewConfig()
	cfg.DBName = dbName

	var err error
	cfg.User, cfg.Passwd, err = getCredentials()
	if err != nil {
		return nil, fmt.Errorf("error reading user credentials for DB: %w", err)
	}

	sqlDB, err := sql.Open("mysql", cfg.FormatDSN())
	if err != nil {
		return nil, fmt.Errorf("error opening DB: %w", err)
	}

	if err := sqlDB.Ping(); err != nil {
		// sql.Open already created a connection pool. Release it instead of
		// leaking it when the server is unreachable or rejects the
		// credentials. The close error is ignored: the ping error is the one
		// worth reporting.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("error pinging DB: %w", err)
	}

	return &DB{DB: sqlDB}, nil
}
|
2024-02-22 18:49:51 +01:00
|
|
|
|
2024-03-10 15:03:46 +01:00
|
|
|
func (db *DB) UpdateAttributes(a ...*Attribute) error {
|
2024-03-15 15:18:02 +01:00
|
|
|
for i := 0; i < TxMaxRetries; i++ {
|
|
|
|
err := func() error {
|
|
|
|
tx, err := db.Begin()
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error starting transaction: %v", err)
|
|
|
|
}
|
2024-03-10 15:03:46 +01:00
|
|
|
|
2024-03-15 15:18:02 +01:00
|
|
|
for _, attribute := range a {
|
|
|
|
query := fmt.Sprintf(`
|
|
|
|
UPDATE %s
|
|
|
|
SET %s = ?
|
|
|
|
WHERE id = ?
|
|
|
|
`, attribute.Table, attribute.AttName)
|
|
|
|
if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil {
|
|
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
|
|
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
|
|
}
|
|
|
|
return fmt.Errorf("error updating %v in DB: %v", attribute.AttName, err)
|
|
|
|
}
|
2024-03-10 15:03:46 +01:00
|
|
|
}
|
2024-03-15 15:18:02 +01:00
|
|
|
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
|
|
return fmt.Errorf("error committing transaction: %v", err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}()
|
|
|
|
if err == nil {
|
|
|
|
return nil
|
2024-03-10 15:03:46 +01:00
|
|
|
}
|
2024-03-15 15:18:02 +01:00
|
|
|
log.Println(err)
|
2024-03-10 15:03:46 +01:00
|
|
|
|
2024-03-15 15:18:02 +01:00
|
|
|
waitTime := time.Duration(math.Pow(2, float64(i))) * time.Second
|
|
|
|
jitter := time.Duration(rand.IntN(1000)) * time.Millisecond
|
|
|
|
time.Sleep(waitTime + jitter)
|
2024-03-07 15:31:00 +01:00
|
|
|
}
|
2024-03-15 15:18:02 +01:00
|
|
|
return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries)
|
2024-03-07 15:31:00 +01:00
|
|
|
}
|
|
|
|
|
2024-03-09 11:06:03 +01:00
|
|
|
// CountEntries returns the number of rows in table, using SELECT COUNT(*).
//
// NOTE(review): table is interpolated into the SQL text verbatim, because
// identifiers cannot be bound as query parameters. It must come from trusted
// code, never from user input.
func (db *DB) CountEntries(table string) (int64, error) {
	var count int64
	query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
	if err := db.QueryRow(query).Scan(&count); err != nil {
		// Name the actual table; the function is not specific to any one table.
		return 0, fmt.Errorf("error counting rows in %s: %w", table, err)
	}
	return count, nil
}
|
2024-03-11 21:08:27 +01:00
|
|
|
|
|
|
|
func (db *DB) StartTransaction() (*Tx, error) {
|
|
|
|
tx := &Tx{Tx: new(sql.Tx)}
|
|
|
|
var err error
|
|
|
|
|
|
|
|
tx.Tx, err = db.Begin()
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error starting transaction: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return tx, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// CommitTransaction commits the transaction. The driver's error is wrapped
// with %w so callers can inspect it with errors.Is / errors.As (e.g. to detect
// sql.ErrTxDone).
func (tx *Tx) CommitTransaction() error {
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("error committing transaction: %w", err)
	}
	return nil
}
|
|
|
|
|
|
|
|
// RollbackTransaction aborts the transaction.
//
// Rolling back a transaction that was already committed or rolled back
// (sql.ErrTxDone) is a no-op. That makes this safe to defer unconditionally,
// and safe to call after Tx.UpdateAttributes failed, since that method rolls
// back on its own. Any other rollback error is logged rather than terminating
// the process.
func (tx *Tx) RollbackTransaction() {
	err := tx.Rollback()
	// database/sql returns the ErrTxDone sentinel itself (unwrapped) from
	// Tx.Rollback, so a direct comparison is sufficient here.
	if err == nil || err == sql.ErrTxDone {
		return
	}
	log.Printf("error rolling back transaction: %v", err)
}
|
|
|
|
|
|
|
|
func (tx *Tx) UpdateAttributes(a ...*Attribute) error {
|
|
|
|
for _, attribute := range a {
|
|
|
|
query := fmt.Sprintf(`
|
|
|
|
UPDATE %s
|
|
|
|
SET %s = ?
|
|
|
|
WHERE id = ?
|
|
|
|
`, attribute.Table, attribute.AttName)
|
|
|
|
if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil {
|
|
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
|
|
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
|
|
}
|
2024-03-12 19:56:22 +01:00
|
|
|
return fmt.Errorf("error updating %v in DB: %v", attribute.AttName, err)
|
2024-03-11 21:08:27 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|