From a1a6b6c29f2cd97a37ade0cefaa3467a34554b58 Mon Sep 17 00:00:00 2001 From: Jason Streifling Date: Sat, 9 Mar 2024 11:06:03 +0100 Subject: [PATCH] Split up db.go into multiple files --- cmd/model/articles.go | 85 ++++++++++ cmd/model/articles_tags.go | 56 +++++++ cmd/model/db.go | 334 ++++++------------------------------- cmd/model/helpers.go | 52 ------ cmd/model/tags.go | 30 ++++ cmd/model/users.go | 133 +++++++++++++++ cmd/view/admin.go | 2 +- cmd/view/sessions.go | 2 +- 8 files changed, 354 insertions(+), 340 deletions(-) create mode 100644 cmd/model/articles.go create mode 100644 cmd/model/articles_tags.go delete mode 100644 cmd/model/helpers.go create mode 100644 cmd/model/tags.go create mode 100644 cmd/model/users.go diff --git a/cmd/model/articles.go b/cmd/model/articles.go new file mode 100644 index 0000000..4f0fe18 --- /dev/null +++ b/cmd/model/articles.go @@ -0,0 +1,85 @@ +package model + +import ( + "fmt" + "time" +) + +func (db *DB) AddArticle(a *Article) (int64, error) { + query := ` + INSERT INTO articles + (title, description, content, published, author_id) + VALUES + (?, ?, ?, ?, ?) + ` + + result, err := db.Exec(query, a.Title, a.Desc, a.Content, a.Published, a.AuthorID) + if err != nil { + return 0, fmt.Errorf("error inserting article into DB: %v", err) + } + id, err := result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("error retrieving last ID: %v", err) + } + + return id, nil +} + +func (db *DB) GetArticle(id int64) (*Article, error) { + query := ` + SELECT title, created, description, content, published, author_id + FROM articles + WHERE id = ? + ` + row := db.QueryRow(query, id) + + article := new(Article) + var created []byte + var err error + + if err := row.Scan(&article.Title, &created, &article.Desc, + &article.Content, &article.Published, &article.AuthorID); err != nil { + return nil, fmt.Errorf("error scanning article row: %v", err) + } + + article.ID = id + article.Created, err = time.Parse("2006-01-02 15:04:05", string(created)) + if err != nil { + return nil, fmt.Errorf("error parsing created: %v", err) + } + + return article, nil +} + +func (db *DB) GetCertainArticles(published bool) ([]*Article, error) { + query := ` + SELECT id, title, created, description, content, author_id + FROM articles + WHERE published = ? + ` + rows, err := db.Query(query, published) + if err != nil { + return nil, fmt.Errorf("error querying articles: %v", err) + } + + articleList := make([]*Article, 0) + for rows.Next() { + article := new(Article) + var created []byte + + if err = rows.Scan(&article.ID, &article.Title, &created, &article.Desc, + &article.Content, &article.AuthorID); err != nil { + return nil, fmt.Errorf("error scanning article row: %v", err) + } + + article.Published = false + article.Created, err = time.Parse("2006-01-02 15:04:05", string(created)) + if err != nil { + return nil, fmt.Errorf("error parsing created: %v", err) + } + + articleList = append(articleList, article) + } + + return articleList, nil +} diff --git a/cmd/model/articles_tags.go b/cmd/model/articles_tags.go new file mode 100644 index 0000000..e2cbe12 --- /dev/null +++ b/cmd/model/articles_tags.go @@ -0,0 +1,56 @@ +package model + +import ( + "fmt" + "log" +) + +func (db *DB) WriteArticleTags(articleID int64, tagIDs []int64) error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("error starting transaction: %v", err) + } + + for _, tagID := range tagIDs { + query := ` + INSERT INTO articles_tags (article_id, tag_id) + VALUES (?, ?) + ` + if _, err := tx.Exec(query, articleID, tagID); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) + } + return fmt.Errorf("error inserting into articles_tags: %v", err) + } + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("error committing transaction: %v", err) + } + return nil +} + +func (db *DB) GetArticleTags(articleID int64) ([]*Tag, error) { + query := ` + SELECT t.id, t.name + FROM articles a + INNER JOIN articles_tags at ON a.id = at.article_id + INNER JOIN tags t ON at.tag_id = t.id + WHERE a.id = ? + ` + rows, err := db.Query(query, articleID) + if err != nil { + return nil, fmt.Errorf("error querying articles_tags: %v", err) + } + + tags := make([]*Tag, 0) + for rows.Next() { + tag := new(Tag) + if err = rows.Scan(&tag.ID, &tag.Name); err != nil { + return nil, fmt.Errorf("error scanning rows: %v", err) + } + tags = append(tags, tag) + } + + return tags, nil +} diff --git a/cmd/model/db.go b/cmd/model/db.go index 95b2d2f..7af4ee2 100644 --- a/cmd/model/db.go +++ b/cmd/model/db.go @@ -1,19 +1,62 @@ package model import ( + "bufio" "database/sql" "fmt" - "log" - "time" + "os" + "strings" + "syscall" "github.com/go-sql-driver/mysql" - "golang.org/x/crypto/bcrypt" + "golang.org/x/term" ) type DB struct { *sql.DB } +func getUsername() (string, error) { + user := os.Getenv("DB_USER") + if user == "" { + var err error + fmt.Printf("DB Benutzer: ") + user, err = bufio.NewReader(os.Stdin).ReadString('\n') + if err != nil { + return "", fmt.Errorf("error reading username: %v", err) + } + } + return strings.TrimSpace(user), nil +} + +func getPassword() (string, error) { + pass := os.Getenv("DB_PASS") + if pass == "" { + fmt.Printf("DB Passwort: ") + bytePass, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + return "", fmt.Errorf("error reading password: %v", err) + } + fmt.Println() + pass = strings.TrimSpace(string(bytePass)) + } + return pass, nil +} + +func getCredentials() (string, string, error) { + user, err := getUsername() + if err != nil { + return "", "", fmt.Errorf("error getting username: %v", err) + } + + pass, err := getPassword() + if err != nil { + return "", "", fmt.Errorf("error getting password: %v", err) + } + + return user, pass, nil +} + func OpenDB(dbName string) (*DB, error) { var err error db := DB{DB: &sql.DB{}} @@ -48,118 +91,10 @@ func (db *DB) UpdateAttribute(table string, id int64, attribute string, val inte return nil } -func (db *DB) AddUser(user *User, pass string) error { - hashedPass, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost) - if err != nil { - return fmt.Errorf("error creating password hash: %v", err) - } - - query := ` - INSERT INTO users - (username, password, first_name, last_name, role) - VALUES (?, ?, ?, ?, ?) - ` - if _, err = db.Exec(query, user.UserName, string(hashedPass), user.FirstName, user.LastName, user.Role); err != nil { - return fmt.Errorf("error inserting user into DB: %v", err) - } - - return nil -} - -func (db *DB) GetID(userName string) (int64, error) { - var id int64 - - query := ` - SELECT id - FROM users - WHERE username = ? - ` - row := db.QueryRow(query, userName) - if err := row.Scan(&id); err != nil { - return 0, fmt.Errorf("user not in DB: %v", err) - } - - return id, nil -} - -func (db *DB) CheckPassword(id int64, pass string) error { - var queriedPass string - - query := ` - SELECT password - FROM users - WHERE id = ? - ` - row := db.QueryRow(query, id) - if err := row.Scan(&queriedPass); err != nil { - return fmt.Errorf("error reading password from DB: %v", err) - } - - if err := bcrypt.CompareHashAndPassword([]byte(queriedPass), []byte(pass)); err != nil { - return fmt.Errorf("incorrect password: %v", err) - } - - return nil -} - -func (db *DB) ChangePassword(id int64, oldPass, newPass string) error { - tx, err := db.Begin() - if err != nil { - return fmt.Errorf("error starting transaction: %v", err) - } - - var queriedPass string - getQuery := ` - SELECT password - FROM users - WHERE id = ? - ` - row := tx.QueryRow(getQuery, id) - if err := row.Scan(&queriedPass); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) - } - return fmt.Errorf("error reading password from DB: %v", err) - } - - if err := bcrypt.CompareHashAndPassword([]byte(queriedPass), []byte(oldPass)); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) - } - return fmt.Errorf("incorrect password: %v", err) - } - - newHashedPass, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost) - if err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) - } - return fmt.Errorf("error creating password hash: %v", err) - } - - setQuery := ` - UPDATE users - SET password = ? - WHERE id = ? - ` - if _, err = tx.Exec(setQuery, string(newHashedPass), id); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) - } - return fmt.Errorf("error updating password in DB: %v", err) - } - - if err = tx.Commit(); err != nil { - return fmt.Errorf("error committing transaction: %v", err) - } - - return nil -} - -func (db *DB) CountEntries() (int64, error) { +func (db *DB) CountEntries(table string) (int64, error) { var count int64 - query := `SELECT COUNT(*) FROM users` + query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table) row := db.QueryRow(query) if err := row.Scan(&count); err != nil { return 0, fmt.Errorf("error counting rows in user DB: %v", err) @@ -167,176 +102,3 @@ func (db *DB) CountEntries() (int64, error) { return count, nil } - -// TODO: No need for ID field in general -func (db *DB) GetUser(id int64) (*User, error) { - user := new(User) - query := ` - SELECT id, username, first_name, last_name, role - FROM users - WHERE id = ? - ` - - row := db.QueryRow(query, id) - if err := row.Scan(&user.ID, &user.UserName, &user.FirstName, &user.LastName, &user.Role); err != nil { - return nil, fmt.Errorf("error reading user information: %v", err) - } - - return user, nil -} - -func (db *DB) AddTag(tagName string) error { - query := "INSERT INTO tags (name) VALUES (?)" - if _, err := db.Exec(query, tagName); err != nil { - return fmt.Errorf("error inserting tag into DB: %v", err) - } - return nil -} - -func (db *DB) GetTagList() ([]*Tag, error) { - query := "SELECT id, name FROM tags" - rows, err := db.Query(query) - if err != nil { - return nil, fmt.Errorf("error querying tags: %v", err) - } - - tagList := make([]*Tag, 0) - for rows.Next() { - tag := new(Tag) - if err = rows.Scan(&tag.ID, &tag.Name); err != nil { - return nil, fmt.Errorf("error scanning tag row: %v", err) - } - tagList = append(tagList, tag) - } - - return tagList, nil -} - -func (db *DB) AddArticle(a *Article) (int64, error) { - query := ` - INSERT INTO articles - (title, description, content, published, author_id) - VALUES - (?, ?, ?, ?, ?) - ` - - result, err := db.Exec(query, a.Title, a.Desc, a.Content, a.Published, a.AuthorID) - if err != nil { - return 0, fmt.Errorf("error inserting article into DB: %v", err) - } - id, err := result.LastInsertId() - if err != nil { - return 0, fmt.Errorf("error retrieving last ID: %v", err) - } - - return id, nil -} - -func (db *DB) GetArticle(id int64) (*Article, error) { - query := ` - SELECT title, created, description, content, published, author_id - FROM articles - WHERE id = ? - ` - row := db.QueryRow(query, id) - - article := new(Article) - var created []byte - var err error - - if err := row.Scan(&article.Title, &created, &article.Desc, - &article.Content, &article.Published, &article.AuthorID); err != nil { - return nil, fmt.Errorf("error scanning article row: %v", err) - } - - article.ID = id - article.Created, err = time.Parse("2006-01-02 15:04:05", string(created)) - if err != nil { - return nil, fmt.Errorf("error parsing created: %v", err) - } - - return article, nil -} - -func (db *DB) GetCertainArticles(published bool) ([]*Article, error) { - query := ` - SELECT id, title, created, description, content, author_id - FROM articles - WHERE published = ? - ` - rows, err := db.Query(query, published) - if err != nil { - return nil, fmt.Errorf("error querying articles: %v", err) - } - - articleList := make([]*Article, 0) - for rows.Next() { - article := new(Article) - var created []byte - - if err = rows.Scan(&article.ID, &article.Title, &created, &article.Desc, - &article.Content, &article.AuthorID); err != nil { - return nil, fmt.Errorf("error scanning article row: %v", err) - } - - article.Published = false - article.Created, err = time.Parse("2006-01-02 15:04:05", string(created)) - if err != nil { - return nil, fmt.Errorf("error parsing created: %v", err) - } - - articleList = append(articleList, article) - } - - return articleList, nil -} - -func (db *DB) WriteArticleTags(articleID int64, tagIDs []int64) error { - tx, err := db.Begin() - if err != nil { - return fmt.Errorf("error starting transaction: %v", err) - } - - for _, tagID := range tagIDs { - query := ` - INSERT INTO articles_tags (article_id, tag_id) - VALUES (?, ?) - ` - if _, err := tx.Exec(query, articleID, tagID); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) - } - return fmt.Errorf("error inserting into articles_tags: %v", err) - } - } - - if err = tx.Commit(); err != nil { - return fmt.Errorf("error committing transaction: %v", err) - } - return nil -} - -func (db *DB) GetArticleTags(articleID int64) ([]*Tag, error) { - query := ` - SELECT t.id, t.name - FROM articles a - INNER JOIN articles_tags at ON a.id = at.article_id - INNER JOIN tags t ON at.tag_id = t.id - WHERE a.id = ? - ` - rows, err := db.Query(query, articleID) - if err != nil { - return nil, fmt.Errorf("error querying articles_tags: %v", err) - } - - tags := make([]*Tag, 0) - for rows.Next() { - tag := new(Tag) - if err = rows.Scan(&tag.ID, &tag.Name); err != nil { - return nil, fmt.Errorf("error scanning rows: %v", err) - } - tags = append(tags, tag) - } - - return tags, nil -} diff --git a/cmd/model/helpers.go b/cmd/model/helpers.go deleted file mode 100644 index 58fbde4..0000000 --- a/cmd/model/helpers.go +++ /dev/null @@ -1,52 +0,0 @@ -package model - -import ( - "bufio" - "fmt" - "os" - "strings" - "syscall" - - "golang.org/x/term" -) - -func getUsername() (string, error) { - user := os.Getenv("DB_USER") - if user == "" { - var err error - fmt.Printf("DB Benutzer: ") - user, err = bufio.NewReader(os.Stdin).ReadString('\n') - if err != nil { - return "", fmt.Errorf("error reading username: %v", err) - } - } - return strings.TrimSpace(user), nil -} - -func getPassword() (string, error) { - pass := os.Getenv("DB_PASS") - if pass == "" { - fmt.Printf("DB Passwort: ") - bytePass, err := term.ReadPassword(int(syscall.Stdin)) - if err != nil { - return "", fmt.Errorf("error reading password: %v", err) - } - fmt.Println() - pass = strings.TrimSpace(string(bytePass)) - } - return pass, nil -} - -func getCredentials() (string, string, error) { - user, err := getUsername() - if err != nil { - return "", "", fmt.Errorf("error getting username: %v", err) - } - - pass, err := getPassword() - if err != nil { - return "", "", fmt.Errorf("error getting password: %v", err) - } - - return user, pass, nil -} diff --git a/cmd/model/tags.go b/cmd/model/tags.go new file mode 100644 index 0000000..d4a80a9 --- /dev/null +++ b/cmd/model/tags.go @@ -0,0 +1,30 @@ +package model + +import "fmt" + +func (db *DB) AddTag(tagName string) error { + query := "INSERT INTO tags (name) VALUES (?)" + if _, err := db.Exec(query, tagName); err != nil { + return fmt.Errorf("error inserting tag into DB: %v", err) + } + return nil +} + +func (db *DB) GetTagList() ([]*Tag, error) { + query := "SELECT id, name FROM tags" + rows, err := db.Query(query) + if err != nil { + return nil, fmt.Errorf("error querying tags: %v", err) + } + + tagList := make([]*Tag, 0) + for rows.Next() { + tag := new(Tag) + if err = rows.Scan(&tag.ID, &tag.Name); err != nil { + return nil, fmt.Errorf("error scanning tag row: %v", err) + } + tagList = append(tagList, tag) + } + + return tagList, nil +} diff --git a/cmd/model/users.go b/cmd/model/users.go new file mode 100644 index 0000000..dbe5888 --- /dev/null +++ b/cmd/model/users.go @@ -0,0 +1,133 @@ +package model + +import ( + "fmt" + "log" + + "golang.org/x/crypto/bcrypt" +) + +func (db *DB) AddUser(user *User, pass string) error { + hashedPass, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("error creating password hash: %v", err) + } + + query := ` + INSERT INTO users + (username, password, first_name, last_name, role) + VALUES (?, ?, ?, ?, ?) + ` + if _, err = db.Exec(query, user.UserName, string(hashedPass), user.FirstName, user.LastName, user.Role); err != nil { + return fmt.Errorf("error inserting user into DB: %v", err) + } + + return nil +} + +func (db *DB) GetID(userName string) (int64, error) { + var id int64 + + query := ` + SELECT id + FROM users + WHERE username = ? + ` + row := db.QueryRow(query, userName) + if err := row.Scan(&id); err != nil { + return 0, fmt.Errorf("user not in DB: %v", err) + } + + return id, nil +} + +func (db *DB) CheckPassword(id int64, pass string) error { + var queriedPass string + + query := ` + SELECT password + FROM users + WHERE id = ? + ` + row := db.QueryRow(query, id) + if err := row.Scan(&queriedPass); err != nil { + return fmt.Errorf("error reading password from DB: %v", err) + } + + if err := bcrypt.CompareHashAndPassword([]byte(queriedPass), []byte(pass)); err != nil { + return fmt.Errorf("incorrect password: %v", err) + } + + return nil +} + +func (db *DB) ChangePassword(id int64, oldPass, newPass string) error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("error starting transaction: %v", err) + } + + var queriedPass string + getQuery := ` + SELECT password + FROM users + WHERE id = ? + ` + row := tx.QueryRow(getQuery, id) + if err := row.Scan(&queriedPass); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) + } + return fmt.Errorf("error reading password from DB: %v", err) + } + + if err := bcrypt.CompareHashAndPassword([]byte(queriedPass), []byte(oldPass)); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) + } + return fmt.Errorf("incorrect password: %v", err) + } + + newHashedPass, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost) + if err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) + } + return fmt.Errorf("error creating password hash: %v", err) + } + + setQuery := ` + UPDATE users + SET password = ? + WHERE id = ? + ` + if _, err = tx.Exec(setQuery, string(newHashedPass), id); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) + } + return fmt.Errorf("error updating password in DB: %v", err) + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("error committing transaction: %v", err) + } + + return nil +} + +// TODO: No need for ID field in general +func (db *DB) GetUser(id int64) (*User, error) { + user := new(User) + query := ` + SELECT id, username, first_name, last_name, role + FROM users + WHERE id = ? + ` + + row := db.QueryRow(query, id) + if err := row.Scan(&user.ID, &user.UserName, &user.FirstName, &user.LastName, &user.Role); err != nil { + return nil, fmt.Errorf("error reading user information: %v", err) + } + + return user, nil +} diff --git a/cmd/view/admin.go b/cmd/view/admin.go index e79d7c9..b94a222 100644 --- a/cmd/view/admin.go +++ b/cmd/view/admin.go @@ -93,7 +93,7 @@ func AddUser(db *model.DB, s *control.CookieStore) http.HandlerFunc { return } - num, err := db.CountEntries() + num, err := db.CountEntries("users") if err != nil { log.Println(err) http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/cmd/view/sessions.go b/cmd/view/sessions.go index 4aad527..a7bfc72 100644 --- a/cmd/view/sessions.go +++ b/cmd/view/sessions.go @@ -29,7 +29,7 @@ func saveSession(w http.ResponseWriter, r *http.Request, s *control.CookieStore, func HomePage(db *model.DB, s *control.CookieStore) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - numRows, err := db.CountEntries() + numRows, err := db.CountEntries("users") if err != nil { log.Fatalln(err) }