forked from jason/cpolis
162 lines
3.6 KiB
Go
162 lines
3.6 KiB
Go
package backend
|
|
|
|
import (
|
|
"bufio"
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"math"
|
|
"math/rand/v2"
|
|
"os"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/go-sql-driver/mysql"
|
|
"golang.org/x/term"
|
|
)
|
|
|
|
var (
	// TxMaxRetries is how many times DB.UpdateAttributes attempts its
	// transaction before giving up. Being an exported variable, it can be
	// adjusted by other packages.
	TxMaxRetries = 5
)
|
|
|
|
// DB embeds a *sql.DB, so all of its methods are promoted; this package adds
// attribute-update and counting helpers on top.
type DB struct {
	*sql.DB
}

// Tx embeds a *sql.Tx so package helpers can be attached to a transaction.
type Tx struct {
	*sql.Tx
}

// Attribute names one column value in one row: column AttName of table Table
// in the row whose id equals ID, together with the Value to store there.
type Attribute struct {
	Value   any    // new column value, bound as a query parameter
	Table   string // table name, interpolated into the SQL text
	AttName string // column name, interpolated into the SQL text
	ID      int64  // matched against the row's id column
}
|
|
|
|
func getDBUsername(c *Config) error {
|
|
if c.DBUser == "" {
|
|
var err error
|
|
|
|
fmt.Printf("DB Benutzer: ")
|
|
c.DBUser, err = bufio.NewReader(os.Stdin).ReadString('\n')
|
|
if err != nil {
|
|
return fmt.Errorf("error reading username: %v", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getDBPassword(c *Config) error {
|
|
if c.DBPass == "" {
|
|
fmt.Printf("DB Passwort: ")
|
|
bytePass, err := term.ReadPassword(int(syscall.Stdin))
|
|
if err != nil {
|
|
return fmt.Errorf("error reading password: %v", err)
|
|
}
|
|
|
|
fmt.Println()
|
|
c.DBPass = strings.TrimSpace(string(bytePass))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func wait(iteration int) {
|
|
waitTime := time.Duration(math.Pow(2, float64(iteration))) * 100 * time.Millisecond
|
|
jitter := time.Duration(rand.IntN(int(waitTime)/2)) * time.Millisecond
|
|
time.Sleep(waitTime + jitter)
|
|
}
|
|
|
|
func OpenDB(c *Config) (*DB, error) {
|
|
var err error
|
|
db := new(DB)
|
|
|
|
if err := getDBUsername(c); err != nil {
|
|
return nil, fmt.Errorf("error getting DB username: %v", err)
|
|
}
|
|
|
|
if err := getDBPassword(c); err != nil {
|
|
return nil, fmt.Errorf("error getting DB password: %v", err)
|
|
}
|
|
|
|
cfg := mysql.NewConfig()
|
|
cfg.DBName = c.DBName
|
|
cfg.User = c.DBUser
|
|
cfg.Passwd = c.DBPass
|
|
|
|
db.DB, err = sql.Open("mysql", cfg.FormatDSN())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error opening DB: %v", err)
|
|
}
|
|
|
|
if err = db.Ping(); err != nil {
|
|
return nil, fmt.Errorf("error pinging DB: %v", err)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func (db *DB) UpdateAttributes(a ...*Attribute) error {
|
|
for i := 0; i < TxMaxRetries; i++ {
|
|
err := func() error {
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("error starting transaction: %v", err)
|
|
}
|
|
|
|
for _, attribute := range a {
|
|
query := fmt.Sprintf(`
|
|
UPDATE %s
|
|
SET %s = ?
|
|
WHERE id = ?
|
|
`, attribute.Table, attribute.AttName)
|
|
if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
log.Fatalf("transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
}
|
|
return fmt.Errorf("error updating %v in DB: %v", attribute.AttName, err)
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return fmt.Errorf("error committing transaction: %v", err)
|
|
}
|
|
return nil
|
|
}()
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
log.Println(err)
|
|
wait(i)
|
|
}
|
|
return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries)
|
|
}
|
|
|
|
func (db *DB) CountEntries(table string) (int64, error) {
|
|
var count int64
|
|
|
|
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
|
|
row := db.QueryRow(query)
|
|
if err := row.Scan(&count); err != nil {
|
|
return 0, fmt.Errorf("error counting rows in user DB: %v", err)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func (tx *Tx) UpdateAttributes(a ...*Attribute) error {
|
|
for _, attribute := range a {
|
|
query := fmt.Sprintf(`
|
|
UPDATE %s
|
|
SET %s = ?
|
|
WHERE id = ?
|
|
`, attribute.Table, attribute.AttName)
|
|
if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
log.Fatalf("transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
}
|
|
return fmt.Errorf("error updating %v in DB: %v", attribute.AttName, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|