From c45df4bf1adf31836a686b589f7ee22e6b2dfd2b Mon Sep 17 00:00:00 2001 From: Jason Streifling Date: Fri, 15 Mar 2024 18:37:24 +0100 Subject: [PATCH] Implemented retry logic on all transactions --- cmd/model/articles_tags.go | 54 +++++++++++++++++++++------------ cmd/model/db.go | 26 +++++++--------- cmd/model/users.go | 61 ++++++++++++++++++++++++++++++++++++++ cmd/view/users.go | 52 +++++++------------------------- 4 files changed, 118 insertions(+), 75 deletions(-) diff --git a/cmd/model/articles_tags.go b/cmd/model/articles_tags.go index e2cbe12..9dcc280 100644 --- a/cmd/model/articles_tags.go +++ b/cmd/model/articles_tags.go @@ -3,31 +3,47 @@ package model import ( "fmt" "log" + "math" + "math/rand/v2" + "time" ) func (db *DB) WriteArticleTags(articleID int64, tagIDs []int64) error { - tx, err := db.Begin() - if err != nil { - return fmt.Errorf("error starting transaction: %v", err) - } - - for _, tagID := range tagIDs { - query := ` - INSERT INTO articles_tags (article_id, tag_id) - VALUES (?, ?) - ` - if _, err := tx.Exec(query, articleID, tagID); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) + for i := 0; i < TxMaxRetries; i++ { + err := func() error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("error starting transaction: %v", err) } - return fmt.Errorf("error inserting into articles_tags: %v", err) - } - } - if err = tx.Commit(); err != nil { - return fmt.Errorf("error committing transaction: %v", err) + for _, tagID := range tagIDs { + query := ` + INSERT INTO articles_tags (article_id, tag_id) + VALUES (?, ?) + ` + if _, err := tx.Exec(query, articleID, tagID); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) + } + return fmt.Errorf("error inserting into articles_tags: %v", err) + } + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("error committing transaction: %v", err) + } + return nil + }() + if err == nil { + return nil + } + + log.Println(err) + waitTime := time.Duration(math.Pow(2, float64(i))) * time.Second + jitter := time.Duration(rand.IntN(1000)) * time.Millisecond + time.Sleep(waitTime + jitter) } - return nil + return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries) } func (db *DB) GetArticleTags(articleID int64) ([]*Tag, error) { diff --git a/cmd/model/db.go b/cmd/model/db.go index 4720b5b..cc11f22 100644 --- a/cmd/model/db.go +++ b/cmd/model/db.go @@ -18,20 +18,16 @@ import ( var TxMaxRetries = 3 -type DB struct { - *sql.DB -} - -type Tx struct { - *sql.Tx -} - -type Attribute struct { - Value interface{} - Table string - AttName string - ID int64 -} +type ( + DB struct{ *sql.DB } + Tx struct{ *sql.Tx } + Attribute struct { + Value interface{} + Table string + AttName string + ID int64 + } +) func getUsername() (string, error) { user := os.Getenv("DB_USER") @@ -126,8 +122,8 @@ func (db *DB) UpdateAttributes(a ...*Attribute) error { if err == nil { return nil } - log.Println(err) + log.Println(err) waitTime := time.Duration(math.Pow(2, float64(i))) * time.Second jitter := time.Duration(rand.IntN(1000)) * time.Millisecond time.Sleep(waitTime + jitter) diff --git a/cmd/model/users.go b/cmd/model/users.go index d5d45f3..1f2b4b9 100644 --- a/cmd/model/users.go +++ b/cmd/model/users.go @@ -3,6 +3,9 @@ package model import ( "fmt" "log" + "math" + "math/rand/v2" + "time" "golang.org/x/crypto/bcrypt" ) @@ -135,3 +138,61 @@ func (db *DB) GetUser(id int64) (*User, error) { return user, nil } + +func (db *DB) UpdateUserAttributes(id int64, user, first, last, oldPass, newPass, newPass2 string) error { + passwordEmpty := true + if len(newPass) > 0 || len(newPass2) > 0 { + if newPass != newPass2 { + return fmt.Errorf("error: passwords do not match") + } + passwordEmpty = false + } + + tx := new(Tx) + var err error + + for i := 0; i < TxMaxRetries; i++ { + err := func() error { + tx.Tx, err = db.Begin() + if err != nil { + return fmt.Errorf("error starting transaction: %v", err) + } + + if !passwordEmpty { + if err = tx.ChangePassword(id, oldPass, newPass); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) + } + return fmt.Errorf("error changing password: %v", err) + } + } + + if err = tx.UpdateAttributes( + &Attribute{Table: "users", ID: id, AttName: "username", Value: user}, + &Attribute{Table: "users", ID: id, AttName: "first_name", Value: first}, + &Attribute{Table: "users", ID: id, AttName: "last_name", Value: last}, + ); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr) + } + return fmt.Errorf("error updating attributes in DB: %v", err) + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("error committing transaction: %v", err) + } + + return nil + }() + if err == nil { + return nil + } + + log.Println(err) + waitTime := time.Duration(math.Pow(2, float64(i))) * time.Second + jitter := time.Duration(rand.IntN(1000)) * time.Millisecond + time.Sleep(waitTime + jitter) + } + + return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries) +} diff --git a/cmd/view/users.go b/cmd/view/users.go index acd1e9d..3403eeb 100644 --- a/cmd/view/users.go +++ b/cmd/view/users.go @@ -192,47 +192,17 @@ func UpdateUser(db *model.DB, s *control.CookieStore) http.HandlerFunc { } } - tx, err := db.StartTransaction() - if err != nil { - log.Println(err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - if len(newPass) > 0 || len(newPass2) > 0 { - if newPass != newPass2 { - tx.RollbackTransaction() - userData.Msg = "Die Passwörter stimmen nicht überein." - tmpl, err := template.ParseFiles("web/templates/edit-user.html") - tmpl = template.Must(tmpl, err) - tmpl.ExecuteTemplate(w, "page-content", userData) - return - } - - if err = tx.ChangePassword(userData.ID, oldPass, newPass); err != nil { - log.Println(err) - userData.Msg = "Das alte Passwort ist nicht korrekt." - tmpl, err := template.ParseFiles("web/templates/edit-user.html") - tmpl = template.Must(tmpl, err) - tmpl.ExecuteTemplate(w, "page-content", userData) - return - } - } - - if err = tx.UpdateAttributes( - &model.Attribute{Table: "users", ID: userData.ID, AttName: "username", Value: userData.UserName}, - &model.Attribute{Table: "users", ID: userData.ID, AttName: "first_name", Value: userData.FirstName}, - &model.Attribute{Table: "users", ID: userData.ID, AttName: "last_name", Value: userData.LastName}, - ); err != nil { - log.Println(err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - if err = tx.Commit(); err != nil { - log.Println(err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return + if err = db.UpdateUserAttributes( + userData.ID, + userData.UserName, + userData.FirstName, + userData.LastName, + oldPass, + newPass, + newPass2); err != nil { + userData.Msg = "Aktualisierung der Benutzerdaten fehlgeschlagen." + tmpl, err := template.ParseFiles("web/templates/edit-user.html") + template.Must(tmpl, err).ExecuteTemplate(w, "page-content", userData) } tmpl, err := template.ParseFiles("web/templates/hub.html")