package model import ( "bufio" "database/sql" "fmt" "log" "os" "strings" "syscall" "github.com/go-sql-driver/mysql" "golang.org/x/term" ) type DB struct { *sql.DB } type Tx struct { *sql.Tx } type Attribute struct { Value interface{} Table string AttName string ID int64 } func getUsername() (string, error) { user := os.Getenv("DB_USER") if user == "" { var err error fmt.Printf("DB Benutzer: ") user, err = bufio.NewReader(os.Stdin).ReadString('\n') if err != nil { return "", fmt.Errorf("error reading username: %v", err) } } return strings.TrimSpace(user), nil } func getPassword() (string, error) { pass := os.Getenv("DB_PASS") if pass == "" { fmt.Printf("DB Passwort: ") bytePass, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { return "", fmt.Errorf("error reading password: %v", err) } fmt.Println() pass = strings.TrimSpace(string(bytePass)) } return pass, nil } func getCredentials() (string, string, error) { user, err := getUsername() if err != nil { return "", "", fmt.Errorf("error getting username: %v", err) } pass, err := getPassword() if err != nil { return "", "", fmt.Errorf("error getting password: %v", err) } return user, pass, nil } func OpenDB(dbName string) (*DB, error) { var err error db := DB{DB: &sql.DB{}} cfg := mysql.NewConfig() cfg.DBName = dbName cfg.User, cfg.Passwd, err = getCredentials() if err != nil { return nil, fmt.Errorf("error reading user credentials for DB: %v", err) } db.DB, err = sql.Open("mysql", cfg.FormatDSN()) if err != nil { return nil, fmt.Errorf("error opening DB: %v", err) } if err = db.Ping(); err != nil { return nil, fmt.Errorf("error pinging DB: %v", err) } return &db, nil } func (db *DB) UpdateAttributes(a ...*Attribute) error { tx, err := db.Begin() if err != nil { return fmt.Errorf("error starting transaction: %v", err) } for _, attribute := range a { query := fmt.Sprintf(` UPDATE %s SET %s = ? WHERE id = ? `, attribute.Table, attribute.AttName) if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) } return fmt.Errorf("error updating article in DB: %v", err) } } if err = tx.Commit(); err != nil { return fmt.Errorf("error committing transaction: %v", err) } return nil } func (db *DB) CountEntries(table string) (int64, error) { var count int64 query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table) row := db.QueryRow(query) if err := row.Scan(&count); err != nil { return 0, fmt.Errorf("error counting rows in user DB: %v", err) } return count, nil } func (db *DB) StartTransaction() (*Tx, error) { tx := &Tx{Tx: new(sql.Tx)} var err error tx.Tx, err = db.Begin() if err != nil { return nil, fmt.Errorf("error starting transaction: %v", err) } return tx, nil } func (tx *Tx) CommitTransaction() error { if err := tx.Commit(); err != nil { return fmt.Errorf("error committing transaction: %v", err) } return nil } func (tx *Tx) RollbackTransaction() { if err := tx.Rollback(); err != nil { log.Fatalf("error rolling back transaction: %v", err) } } func (tx *Tx) UpdateAttributes(a ...*Attribute) error { for _, attribute := range a { query := fmt.Sprintf(` UPDATE %s SET %s = ? WHERE id = ? `, attribute.Table, attribute.AttName) if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) } return fmt.Errorf("error updating article in DB: %v", err) } } return nil }