Split up db.go into multiple files
This commit is contained in:
parent
42596756de
commit
a1a6b6c29f
85
cmd/model/articles.go
Normal file
85
cmd/model/articles.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (db *DB) AddArticle(a *Article) (int64, error) {
|
||||||
|
query := `
|
||||||
|
INSERT INTO articles
|
||||||
|
(title, description, content, published, author_id)
|
||||||
|
VALUES
|
||||||
|
(?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
|
||||||
|
result, err := db.Exec(query, a.Title, a.Desc, a.Content, a.Published, a.AuthorID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error inserting article into DB: %v", err)
|
||||||
|
}
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error retrieving last ID: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) GetArticle(id int64) (*Article, error) {
|
||||||
|
query := `
|
||||||
|
SELECT title, created, description, content, published, author_id
|
||||||
|
FROM articles
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
row := db.QueryRow(query, id)
|
||||||
|
|
||||||
|
article := new(Article)
|
||||||
|
var created []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if err := row.Scan(&article.Title, &created, &article.Desc,
|
||||||
|
&article.Content, &article.Published, &article.AuthorID); err != nil {
|
||||||
|
return nil, fmt.Errorf("error scanning article row: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
article.ID = id
|
||||||
|
article.Created, err = time.Parse("2006-01-02 15:04:05", string(created))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing created: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return article, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) GetCertainArticles(published bool) ([]*Article, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, title, created, description, content, author_id
|
||||||
|
FROM articles
|
||||||
|
WHERE published = ?
|
||||||
|
`
|
||||||
|
rows, err := db.Query(query, published)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error querying articles: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
articleList := make([]*Article, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
article := new(Article)
|
||||||
|
var created []byte
|
||||||
|
|
||||||
|
if err = rows.Scan(&article.ID, &article.Title, &created, &article.Desc,
|
||||||
|
&article.Content, &article.AuthorID); err != nil {
|
||||||
|
return nil, fmt.Errorf("error scanning article row: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
article.Published = false
|
||||||
|
article.Created, err = time.Parse("2006-01-02 15:04:05", string(created))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing created: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
articleList = append(articleList, article)
|
||||||
|
}
|
||||||
|
|
||||||
|
return articleList, nil
|
||||||
|
}
|
56
cmd/model/articles_tags.go
Normal file
56
cmd/model/articles_tags.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (db *DB) WriteArticleTags(articleID int64, tagIDs []int64) error {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error starting transaction: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tagID := range tagIDs {
|
||||||
|
query := `
|
||||||
|
INSERT INTO articles_tags (article_id, tag_id)
|
||||||
|
VALUES (?, ?)
|
||||||
|
`
|
||||||
|
if _, err := tx.Exec(query, articleID, tagID); err != nil {
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error inserting into articles_tags: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("error committing transaction: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) GetArticleTags(articleID int64) ([]*Tag, error) {
|
||||||
|
query := `
|
||||||
|
SELECT t.id, t.name
|
||||||
|
FROM articles a
|
||||||
|
INNER JOIN articles_tags at ON a.id = at.article_id
|
||||||
|
INNER JOIN tags t ON at.tag_id = t.id
|
||||||
|
WHERE a.id = ?
|
||||||
|
`
|
||||||
|
rows, err := db.Query(query, articleID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error querying articles_tags: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tags := make([]*Tag, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
tag := new(Tag)
|
||||||
|
if err = rows.Scan(&tag.ID, &tag.Name); err != nil {
|
||||||
|
return nil, fmt.Errorf("error scanning rows: %v", err)
|
||||||
|
}
|
||||||
|
tags = append(tags, tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags, nil
|
||||||
|
}
|
334
cmd/model/db.go
334
cmd/model/db.go
@ -1,19 +1,62 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"os"
|
||||||
"time"
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/go-sql-driver/mysql"
|
"github.com/go-sql-driver/mysql"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/term"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
*sql.DB
|
*sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getUsername() (string, error) {
|
||||||
|
user := os.Getenv("DB_USER")
|
||||||
|
if user == "" {
|
||||||
|
var err error
|
||||||
|
fmt.Printf("DB Benutzer: ")
|
||||||
|
user, err = bufio.NewReader(os.Stdin).ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error reading username: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(user), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPassword() (string, error) {
|
||||||
|
pass := os.Getenv("DB_PASS")
|
||||||
|
if pass == "" {
|
||||||
|
fmt.Printf("DB Passwort: ")
|
||||||
|
bytePass, err := term.ReadPassword(int(syscall.Stdin))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error reading password: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
pass = strings.TrimSpace(string(bytePass))
|
||||||
|
}
|
||||||
|
return pass, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCredentials() (string, string, error) {
|
||||||
|
user, err := getUsername()
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("error getting username: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pass, err := getPassword()
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("error getting password: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, pass, nil
|
||||||
|
}
|
||||||
|
|
||||||
func OpenDB(dbName string) (*DB, error) {
|
func OpenDB(dbName string) (*DB, error) {
|
||||||
var err error
|
var err error
|
||||||
db := DB{DB: &sql.DB{}}
|
db := DB{DB: &sql.DB{}}
|
||||||
@ -48,118 +91,10 @@ func (db *DB) UpdateAttribute(table string, id int64, attribute string, val inte
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) AddUser(user *User, pass string) error {
|
func (db *DB) CountEntries(table string) (int64, error) {
|
||||||
hashedPass, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error creating password hash: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
query := `
|
|
||||||
INSERT INTO users
|
|
||||||
(username, password, first_name, last_name, role)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
`
|
|
||||||
if _, err = db.Exec(query, user.UserName, string(hashedPass), user.FirstName, user.LastName, user.Role); err != nil {
|
|
||||||
return fmt.Errorf("error inserting user into DB: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) GetID(userName string) (int64, error) {
|
|
||||||
var id int64
|
|
||||||
|
|
||||||
query := `
|
|
||||||
SELECT id
|
|
||||||
FROM users
|
|
||||||
WHERE username = ?
|
|
||||||
`
|
|
||||||
row := db.QueryRow(query, userName)
|
|
||||||
if err := row.Scan(&id); err != nil {
|
|
||||||
return 0, fmt.Errorf("user not in DB: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return id, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) CheckPassword(id int64, pass string) error {
|
|
||||||
var queriedPass string
|
|
||||||
|
|
||||||
query := `
|
|
||||||
SELECT password
|
|
||||||
FROM users
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
row := db.QueryRow(query, id)
|
|
||||||
if err := row.Scan(&queriedPass); err != nil {
|
|
||||||
return fmt.Errorf("error reading password from DB: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(queriedPass), []byte(pass)); err != nil {
|
|
||||||
return fmt.Errorf("incorrect password: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) ChangePassword(id int64, oldPass, newPass string) error {
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error starting transaction: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var queriedPass string
|
|
||||||
getQuery := `
|
|
||||||
SELECT password
|
|
||||||
FROM users
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
row := tx.QueryRow(getQuery, id)
|
|
||||||
if err := row.Scan(&queriedPass); err != nil {
|
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("error reading password from DB: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(queriedPass), []byte(oldPass)); err != nil {
|
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("incorrect password: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
newHashedPass, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("error creating password hash: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
setQuery := `
|
|
||||||
UPDATE users
|
|
||||||
SET password = ?
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
if _, err = tx.Exec(setQuery, string(newHashedPass), id); err != nil {
|
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("error updating password in DB: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = tx.Commit(); err != nil {
|
|
||||||
return fmt.Errorf("error committing transaction: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) CountEntries() (int64, error) {
|
|
||||||
var count int64
|
var count int64
|
||||||
|
|
||||||
query := `SELECT COUNT(*) FROM users`
|
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
|
||||||
row := db.QueryRow(query)
|
row := db.QueryRow(query)
|
||||||
if err := row.Scan(&count); err != nil {
|
if err := row.Scan(&count); err != nil {
|
||||||
return 0, fmt.Errorf("error counting rows in user DB: %v", err)
|
return 0, fmt.Errorf("error counting rows in user DB: %v", err)
|
||||||
@ -167,176 +102,3 @@ func (db *DB) CountEntries() (int64, error) {
|
|||||||
|
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: No need for ID field in general
|
|
||||||
func (db *DB) GetUser(id int64) (*User, error) {
|
|
||||||
user := new(User)
|
|
||||||
query := `
|
|
||||||
SELECT id, username, first_name, last_name, role
|
|
||||||
FROM users
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
|
|
||||||
row := db.QueryRow(query, id)
|
|
||||||
if err := row.Scan(&user.ID, &user.UserName, &user.FirstName, &user.LastName, &user.Role); err != nil {
|
|
||||||
return nil, fmt.Errorf("error reading user information: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) AddTag(tagName string) error {
|
|
||||||
query := "INSERT INTO tags (name) VALUES (?)"
|
|
||||||
if _, err := db.Exec(query, tagName); err != nil {
|
|
||||||
return fmt.Errorf("error inserting tag into DB: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) GetTagList() ([]*Tag, error) {
|
|
||||||
query := "SELECT id, name FROM tags"
|
|
||||||
rows, err := db.Query(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error querying tags: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tagList := make([]*Tag, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
tag := new(Tag)
|
|
||||||
if err = rows.Scan(&tag.ID, &tag.Name); err != nil {
|
|
||||||
return nil, fmt.Errorf("error scanning tag row: %v", err)
|
|
||||||
}
|
|
||||||
tagList = append(tagList, tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tagList, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) AddArticle(a *Article) (int64, error) {
|
|
||||||
query := `
|
|
||||||
INSERT INTO articles
|
|
||||||
(title, description, content, published, author_id)
|
|
||||||
VALUES
|
|
||||||
(?, ?, ?, ?, ?)
|
|
||||||
`
|
|
||||||
|
|
||||||
result, err := db.Exec(query, a.Title, a.Desc, a.Content, a.Published, a.AuthorID)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("error inserting article into DB: %v", err)
|
|
||||||
}
|
|
||||||
id, err := result.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("error retrieving last ID: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return id, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) GetArticle(id int64) (*Article, error) {
|
|
||||||
query := `
|
|
||||||
SELECT title, created, description, content, published, author_id
|
|
||||||
FROM articles
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
row := db.QueryRow(query, id)
|
|
||||||
|
|
||||||
article := new(Article)
|
|
||||||
var created []byte
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if err := row.Scan(&article.Title, &created, &article.Desc,
|
|
||||||
&article.Content, &article.Published, &article.AuthorID); err != nil {
|
|
||||||
return nil, fmt.Errorf("error scanning article row: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
article.ID = id
|
|
||||||
article.Created, err = time.Parse("2006-01-02 15:04:05", string(created))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing created: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return article, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) GetCertainArticles(published bool) ([]*Article, error) {
|
|
||||||
query := `
|
|
||||||
SELECT id, title, created, description, content, author_id
|
|
||||||
FROM articles
|
|
||||||
WHERE published = ?
|
|
||||||
`
|
|
||||||
rows, err := db.Query(query, published)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error querying articles: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
articleList := make([]*Article, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
article := new(Article)
|
|
||||||
var created []byte
|
|
||||||
|
|
||||||
if err = rows.Scan(&article.ID, &article.Title, &created, &article.Desc,
|
|
||||||
&article.Content, &article.AuthorID); err != nil {
|
|
||||||
return nil, fmt.Errorf("error scanning article row: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
article.Published = false
|
|
||||||
article.Created, err = time.Parse("2006-01-02 15:04:05", string(created))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing created: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
articleList = append(articleList, article)
|
|
||||||
}
|
|
||||||
|
|
||||||
return articleList, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) WriteArticleTags(articleID int64, tagIDs []int64) error {
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error starting transaction: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tagID := range tagIDs {
|
|
||||||
query := `
|
|
||||||
INSERT INTO articles_tags (article_id, tag_id)
|
|
||||||
VALUES (?, ?)
|
|
||||||
`
|
|
||||||
if _, err := tx.Exec(query, articleID, tagID); err != nil {
|
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("error inserting into articles_tags: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = tx.Commit(); err != nil {
|
|
||||||
return fmt.Errorf("error committing transaction: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) GetArticleTags(articleID int64) ([]*Tag, error) {
|
|
||||||
query := `
|
|
||||||
SELECT t.id, t.name
|
|
||||||
FROM articles a
|
|
||||||
INNER JOIN articles_tags at ON a.id = at.article_id
|
|
||||||
INNER JOIN tags t ON at.tag_id = t.id
|
|
||||||
WHERE a.id = ?
|
|
||||||
`
|
|
||||||
rows, err := db.Query(query, articleID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error querying articles_tags: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tags := make([]*Tag, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
tag := new(Tag)
|
|
||||||
if err = rows.Scan(&tag.ID, &tag.Name); err != nil {
|
|
||||||
return nil, fmt.Errorf("error scanning rows: %v", err)
|
|
||||||
}
|
|
||||||
tags = append(tags, tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tags, nil
|
|
||||||
}
|
|
||||||
|
@ -1,52 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/term"
|
|
||||||
)
|
|
||||||
|
|
||||||
func getUsername() (string, error) {
|
|
||||||
user := os.Getenv("DB_USER")
|
|
||||||
if user == "" {
|
|
||||||
var err error
|
|
||||||
fmt.Printf("DB Benutzer: ")
|
|
||||||
user, err = bufio.NewReader(os.Stdin).ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error reading username: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(user), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPassword() (string, error) {
|
|
||||||
pass := os.Getenv("DB_PASS")
|
|
||||||
if pass == "" {
|
|
||||||
fmt.Printf("DB Passwort: ")
|
|
||||||
bytePass, err := term.ReadPassword(int(syscall.Stdin))
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error reading password: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Println()
|
|
||||||
pass = strings.TrimSpace(string(bytePass))
|
|
||||||
}
|
|
||||||
return pass, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCredentials() (string, string, error) {
|
|
||||||
user, err := getUsername()
|
|
||||||
if err != nil {
|
|
||||||
return "", "", fmt.Errorf("error getting username: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pass, err := getPassword()
|
|
||||||
if err != nil {
|
|
||||||
return "", "", fmt.Errorf("error getting password: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, pass, nil
|
|
||||||
}
|
|
30
cmd/model/tags.go
Normal file
30
cmd/model/tags.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func (db *DB) AddTag(tagName string) error {
|
||||||
|
query := "INSERT INTO tags (name) VALUES (?)"
|
||||||
|
if _, err := db.Exec(query, tagName); err != nil {
|
||||||
|
return fmt.Errorf("error inserting tag into DB: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) GetTagList() ([]*Tag, error) {
|
||||||
|
query := "SELECT id, name FROM tags"
|
||||||
|
rows, err := db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error querying tags: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tagList := make([]*Tag, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
tag := new(Tag)
|
||||||
|
if err = rows.Scan(&tag.ID, &tag.Name); err != nil {
|
||||||
|
return nil, fmt.Errorf("error scanning tag row: %v", err)
|
||||||
|
}
|
||||||
|
tagList = append(tagList, tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tagList, nil
|
||||||
|
}
|
133
cmd/model/users.go
Normal file
133
cmd/model/users.go
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (db *DB) AddUser(user *User, pass string) error {
|
||||||
|
hashedPass, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error creating password hash: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `
|
||||||
|
INSERT INTO users
|
||||||
|
(username, password, first_name, last_name, role)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
if _, err = db.Exec(query, user.UserName, string(hashedPass), user.FirstName, user.LastName, user.Role); err != nil {
|
||||||
|
return fmt.Errorf("error inserting user into DB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) GetID(userName string) (int64, error) {
|
||||||
|
var id int64
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT id
|
||||||
|
FROM users
|
||||||
|
WHERE username = ?
|
||||||
|
`
|
||||||
|
row := db.QueryRow(query, userName)
|
||||||
|
if err := row.Scan(&id); err != nil {
|
||||||
|
return 0, fmt.Errorf("user not in DB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) CheckPassword(id int64, pass string) error {
|
||||||
|
var queriedPass string
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT password
|
||||||
|
FROM users
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
row := db.QueryRow(query, id)
|
||||||
|
if err := row.Scan(&queriedPass); err != nil {
|
||||||
|
return fmt.Errorf("error reading password from DB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(queriedPass), []byte(pass)); err != nil {
|
||||||
|
return fmt.Errorf("incorrect password: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) ChangePassword(id int64, oldPass, newPass string) error {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error starting transaction: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var queriedPass string
|
||||||
|
getQuery := `
|
||||||
|
SELECT password
|
||||||
|
FROM users
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
row := tx.QueryRow(getQuery, id)
|
||||||
|
if err := row.Scan(&queriedPass); err != nil {
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error reading password from DB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(queriedPass), []byte(oldPass)); err != nil {
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("incorrect password: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newHashedPass, err := bcrypt.GenerateFromPassword([]byte(newPass), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error creating password hash: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
setQuery := `
|
||||||
|
UPDATE users
|
||||||
|
SET password = ?
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
if _, err = tx.Exec(setQuery, string(newHashedPass), id); err != nil {
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error updating password in DB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("error committing transaction: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: No need for ID field in general
|
||||||
|
func (db *DB) GetUser(id int64) (*User, error) {
|
||||||
|
user := new(User)
|
||||||
|
query := `
|
||||||
|
SELECT id, username, first_name, last_name, role
|
||||||
|
FROM users
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
row := db.QueryRow(query, id)
|
||||||
|
if err := row.Scan(&user.ID, &user.UserName, &user.FirstName, &user.LastName, &user.Role); err != nil {
|
||||||
|
return nil, fmt.Errorf("error reading user information: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
@ -93,7 +93,7 @@ func AddUser(db *model.DB, s *control.CookieStore) http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
num, err := db.CountEntries()
|
num, err := db.CountEntries("users")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
@ -29,7 +29,7 @@ func saveSession(w http.ResponseWriter, r *http.Request, s *control.CookieStore,
|
|||||||
|
|
||||||
func HomePage(db *model.DB, s *control.CookieStore) http.HandlerFunc {
|
func HomePage(db *model.DB, s *control.CookieStore) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
numRows, err := db.CountEntries()
|
numRows, err := db.CountEntries("users")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user