cpolis/cmd/model/db.go

120 lines
2.6 KiB
Go
Raw Normal View History

2024-03-09 10:25:20 +01:00
package model
import (
2024-03-09 11:06:03 +01:00
"bufio"
"database/sql"
"fmt"
"log"
2024-03-09 11:06:03 +01:00
"os"
"strings"
"syscall"
"github.com/go-sql-driver/mysql"
2024-03-09 11:06:03 +01:00
"golang.org/x/term"
)
type DB struct {
*sql.DB
}
2024-03-09 11:06:03 +01:00
// getUsername returns the database user name with surrounding whitespace
// trimmed. It uses the DB_USER environment variable when it is non-empty;
// otherwise it prompts on stdout and reads one line from stdin.
//
// A final line without a trailing newline (piped input, or Ctrl-D after
// typing) is accepted. Completely empty input is reported as an error.
func getUsername() (string, error) {
	user := os.Getenv("DB_USER")
	if user == "" {
		fmt.Printf("DB Benutzer: ")
		sc := bufio.NewScanner(os.Stdin)
		if !sc.Scan() {
			if err := sc.Err(); err != nil {
				return "", fmt.Errorf("error reading username: %w", err)
			}
			// Scan reports false with a nil Err only at EOF before any input.
			return "", fmt.Errorf("error reading username: unexpected end of input")
		}
		user = sc.Text()
	}
	return strings.TrimSpace(user), nil
}
// getPassword returns the database password. When the DB_PASS environment
// variable is non-empty it is returned verbatim. Otherwise the user is
// prompted on stdout and the password is read from stdin with terminal
// echo disabled; surrounding whitespace is trimmed from that input.
//
// NOTE(review): the env-var path is not trimmed while the interactive path
// is, so a password with leading/trailing spaces can only be supplied via
// DB_PASS — confirm this asymmetry is intended.
func getPassword() (string, error) {
	if pass := os.Getenv("DB_PASS"); pass != "" {
		return pass, nil
	}

	fmt.Printf("DB Passwort: ")
	bytePass, err := term.ReadPassword(int(syscall.Stdin))
	// Echo is off while reading, so the user's Enter was never shown. Emit
	// the newline even on failure so subsequent output starts on a new line.
	fmt.Println()
	if err != nil {
		return "", fmt.Errorf("error reading password: %w", err)
	}
	return strings.TrimSpace(string(bytePass)), nil
}
// getCredentials returns the database user name and password, obtained
// from getUsername and getPassword in that order. Errors from either step
// are wrapped with %w so callers can still inspect the underlying cause.
func getCredentials() (string, string, error) {
	user, err := getUsername()
	if err != nil {
		return "", "", fmt.Errorf("error getting username: %w", err)
	}
	pass, err := getPassword()
	if err != nil {
		return "", "", fmt.Errorf("error getting password: %w", err)
	}
	return user, pass, nil
}
func OpenDB(dbName string) (*DB, error) {
var err error
db := DB{DB: &sql.DB{}}
cfg := mysql.NewConfig()
cfg.DBName = dbName
cfg.User, cfg.Passwd, err = getCredentials()
if err != nil {
2024-02-22 15:23:29 +01:00
return nil, fmt.Errorf("error reading user credentials for DB: %v", err)
}
db.DB, err = sql.Open("mysql", cfg.FormatDSN())
if err != nil {
2024-02-22 15:23:29 +01:00
return nil, fmt.Errorf("error opening DB: %v", err)
}
if err = db.Ping(); err != nil {
2024-02-22 15:23:29 +01:00
return nil, fmt.Errorf("error pinging DB: %v", err)
}
return &db, nil
}
2024-02-22 18:49:51 +01:00
// UpdateAttributes writes all given attributes inside a single transaction.
// For each attribute it executes
//
//	UPDATE <Table> SET <AttName> = ? WHERE id = ?
//
// binding the attribute's Value and ID. If any statement fails, the
// transaction is rolled back and that statement's error is returned.
// Otherwise the transaction is committed.
//
// Table and AttName are interpolated into the SQL text, because identifiers
// cannot be bound as parameters. They must therefore come from trusted code,
// never from user input.
func (db *DB) UpdateAttributes(a ...*Attribute) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("error starting transaction: %w", err)
	}
	for _, attribute := range a {
		query := fmt.Sprintf("UPDATE %s SET %s = ? WHERE id = ?", attribute.Table, attribute.AttName)
		if _, err := tx.Exec(query, attribute.Value, attribute.ID); err != nil {
			// A failed rollback is a secondary problem. Log it instead of
			// terminating the process (the previous log.Fatalf), and report
			// the primary error to the caller.
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
				log.Printf("error rolling back transaction after failed update of %s.%s: %v",
					attribute.Table, attribute.AttName, rollbackErr)
			}
			return fmt.Errorf("error updating %s.%s in DB: %w", attribute.Table, attribute.AttName, err)
		}
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("error committing transaction: %w", err)
	}
	return nil
}
2024-03-09 11:06:03 +01:00
func (db *DB) CountEntries(table string) (int64, error) {
2024-02-24 14:49:29 +01:00
var count int64
2024-03-09 11:06:03 +01:00
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
2024-02-24 14:49:29 +01:00
row := db.QueryRow(query)
if err := row.Scan(&count); err != nil {
return 0, fmt.Errorf("error counting rows in user DB: %v", err)
}
return count, nil
}