105 lines
2.2 KiB
Go
105 lines
2.2 KiB
Go
package model
|
|
|
|
import (
	"bufio"
	"database/sql"
	"errors"
	"fmt"
	"io"
	"os"
	"strings"
	"syscall"
	"unicode"

	"github.com/go-sql-driver/mysql"
	"golang.org/x/term"
)
|
|
|
|
type DB struct {
|
|
*sql.DB
|
|
}
|
|
|
|
func getUsername() (string, error) {
|
|
user := os.Getenv("DB_USER")
|
|
if user == "" {
|
|
var err error
|
|
fmt.Printf("DB Benutzer: ")
|
|
user, err = bufio.NewReader(os.Stdin).ReadString('\n')
|
|
if err != nil {
|
|
return "", fmt.Errorf("error reading username: %v", err)
|
|
}
|
|
}
|
|
return strings.TrimSpace(user), nil
|
|
}
|
|
|
|
// getPassword returns the database password.
//
// It uses the DB_PASS environment variable verbatim when that is set.
// Otherwise it prints a prompt to stdout and reads the password from the
// terminal on stdin with echo disabled (term.ReadPassword). An interactively
// typed password has surrounding whitespace trimmed.
func getPassword() (string, error) {
	if pass := os.Getenv("DB_PASS"); pass != "" {
		return pass, nil
	}

	fmt.Print("DB Passwort: ")
	bytePass, err := term.ReadPassword(int(syscall.Stdin))
	// Echo is off while reading, so the user's Enter is not shown. Emit the
	// newline here, even when reading failed, so that subsequent output does
	// not continue on the prompt line.
	fmt.Println()
	if err != nil {
		return "", fmt.Errorf("error reading password: %w", err)
	}
	return strings.TrimSpace(string(bytePass)), nil
}
|
|
|
|
// getCredentials collects the database user name and then the password,
// via getUsername and getPassword. It returns (user, password, nil) on
// success.
//
// Errors from either step are wrapped with %w, so callers can still match
// the underlying cause with errors.Is/errors.As.
func getCredentials() (string, string, error) {
	user, err := getUsername()
	if err != nil {
		return "", "", fmt.Errorf("error getting username: %w", err)
	}

	pass, err := getPassword()
	if err != nil {
		return "", "", fmt.Errorf("error getting password: %w", err)
	}

	return user, pass, nil
}
|
|
|
|
func OpenDB(dbName string) (*DB, error) {
|
|
var err error
|
|
db := DB{DB: &sql.DB{}}
|
|
|
|
cfg := mysql.NewConfig()
|
|
cfg.DBName = dbName
|
|
cfg.User, cfg.Passwd, err = getCredentials()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading user credentials for DB: %v", err)
|
|
}
|
|
|
|
db.DB, err = sql.Open("mysql", cfg.FormatDSN())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error opening DB: %v", err)
|
|
}
|
|
if err = db.Ping(); err != nil {
|
|
return nil, fmt.Errorf("error pinging DB: %v", err)
|
|
}
|
|
|
|
return &db, nil
|
|
}
|
|
|
|
// UpdateAttribute sets column attribute to val on the row(s) of table whose
// id column equals id. val and id are sent as bound query parameters.
//
// Placeholders cannot stand in for identifiers, so table and attribute are
// spliced into the SQL text. To rule out SQL injection through them, each
// must be non-empty and contain only letters, digits, '_', '$' and '.' (the
// dot allows db.table qualification). Anything else is rejected before the
// database is contacted.
//
// The number of affected rows is not checked, so a missing id is not
// reported as an error.
func (db *DB) UpdateAttribute(table string, id int64, attribute string, val interface{}) error {
	notIdentRune := func(r rune) bool {
		return !(r == '_' || r == '$' || r == '.' || unicode.IsLetter(r) || unicode.IsDigit(r))
	}
	for _, ident := range []string{table, attribute} {
		if ident == "" || strings.IndexFunc(ident, notIdentRune) >= 0 {
			return fmt.Errorf("invalid SQL identifier %q", ident)
		}
	}

	query := fmt.Sprintf("UPDATE %s SET %s = ? WHERE id = ?", table, attribute)
	if _, err := db.Exec(query, val, id); err != nil {
		return fmt.Errorf("error updating %s.%s of id %d in DB: %w", table, attribute, id, err)
	}
	return nil
}
|
|
|
|
// CountEntries returns the number of rows in table, using SELECT COUNT(*).
//
// table is spliced into the SQL text because placeholders cannot stand in
// for identifiers. To rule out SQL injection, it must be non-empty and
// contain only letters, digits, '_', '$' and '.' (the dot allows db.table
// qualification). Anything else is rejected before the database is
// contacted.
func (db *DB) CountEntries(table string) (int64, error) {
	notIdentRune := func(r rune) bool {
		return !(r == '_' || r == '$' || r == '.' || unicode.IsLetter(r) || unicode.IsDigit(r))
	}
	if table == "" || strings.IndexFunc(table, notIdentRune) >= 0 {
		return 0, fmt.Errorf("invalid SQL identifier %q", table)
	}

	var count int64
	if err := db.QueryRow("SELECT COUNT(*) FROM " + table).Scan(&count); err != nil {
		return 0, fmt.Errorf("error counting rows in %s: %w", table, err)
	}
	return count, nil
}
|