forked from jason/cpolis
		
	Implemented retry logic on all transactions
This commit is contained in:
		@@ -3,31 +3,47 @@ package model
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"math"
 | 
			
		||||
	"math/rand/v2"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (db *DB) WriteArticleTags(articleID int64, tagIDs []int64) error {
 | 
			
		||||
	tx, err := db.Begin()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error starting transaction: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tagID := range tagIDs {
 | 
			
		||||
		query := `
 | 
			
		||||
        INSERT INTO articles_tags (article_id, tag_id)
 | 
			
		||||
        VALUES (?, ?)
 | 
			
		||||
        `
 | 
			
		||||
		if _, err := tx.Exec(query, articleID, tagID); err != nil {
 | 
			
		||||
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
 | 
			
		||||
				log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
 | 
			
		||||
	for i := 0; i < TxMaxRetries; i++ {
 | 
			
		||||
		err := func() error {
 | 
			
		||||
			tx, err := db.Begin()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return fmt.Errorf("error starting transaction: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			return fmt.Errorf("error inserting into articles_tags: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.Commit(); err != nil {
 | 
			
		||||
		return fmt.Errorf("error committing transaction: %v", err)
 | 
			
		||||
			for _, tagID := range tagIDs {
 | 
			
		||||
				query := `
 | 
			
		||||
                INSERT INTO articles_tags (article_id, tag_id)
 | 
			
		||||
                VALUES (?, ?)
 | 
			
		||||
                `
 | 
			
		||||
				if _, err := tx.Exec(query, articleID, tagID); err != nil {
 | 
			
		||||
					if rollbackErr := tx.Rollback(); rollbackErr != nil {
 | 
			
		||||
						log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
 | 
			
		||||
					}
 | 
			
		||||
					return fmt.Errorf("error inserting into articles_tags: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err = tx.Commit(); err != nil {
 | 
			
		||||
				return fmt.Errorf("error committing transaction: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}()
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		log.Println(err)
 | 
			
		||||
		waitTime := time.Duration(math.Pow(2, float64(i))) * time.Second
 | 
			
		||||
		jitter := time.Duration(rand.IntN(1000)) * time.Millisecond
 | 
			
		||||
		time.Sleep(waitTime + jitter)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
	return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DB) GetArticleTags(articleID int64) ([]*Tag, error) {
 | 
			
		||||
 
 | 
			
		||||
@@ -18,20 +18,16 @@ import (
 | 
			
		||||
 | 
			
		||||
var TxMaxRetries = 3
 | 
			
		||||
 | 
			
		||||
type DB struct {
 | 
			
		||||
	*sql.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Tx struct {
 | 
			
		||||
	*sql.Tx
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Attribute struct {
 | 
			
		||||
	Value   interface{}
 | 
			
		||||
	Table   string
 | 
			
		||||
	AttName string
 | 
			
		||||
	ID      int64
 | 
			
		||||
}
 | 
			
		||||
type (
 | 
			
		||||
	DB        struct{ *sql.DB }
 | 
			
		||||
	Tx        struct{ *sql.Tx }
 | 
			
		||||
	Attribute struct {
 | 
			
		||||
		Value   interface{}
 | 
			
		||||
		Table   string
 | 
			
		||||
		AttName string
 | 
			
		||||
		ID      int64
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getUsername() (string, error) {
 | 
			
		||||
	user := os.Getenv("DB_USER")
 | 
			
		||||
@@ -126,8 +122,8 @@ func (db *DB) UpdateAttributes(a ...*Attribute) error {
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		log.Println(err)
 | 
			
		||||
 | 
			
		||||
		log.Println(err)
 | 
			
		||||
		waitTime := time.Duration(math.Pow(2, float64(i))) * time.Second
 | 
			
		||||
		jitter := time.Duration(rand.IntN(1000)) * time.Millisecond
 | 
			
		||||
		time.Sleep(waitTime + jitter)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,9 @@ package model
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"math"
 | 
			
		||||
	"math/rand/v2"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/crypto/bcrypt"
 | 
			
		||||
)
 | 
			
		||||
@@ -135,3 +138,61 @@ func (db *DB) GetUser(id int64) (*User, error) {
 | 
			
		||||
 | 
			
		||||
	return user, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DB) UpdateUserAttributes(id int64, user, first, last, oldPass, newPass, newPass2 string) error {
 | 
			
		||||
	passwordEmpty := true
 | 
			
		||||
	if len(newPass) > 0 || len(newPass2) > 0 {
 | 
			
		||||
		if newPass != newPass2 {
 | 
			
		||||
			return fmt.Errorf("error: passwords do not match")
 | 
			
		||||
		}
 | 
			
		||||
		passwordEmpty = false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx := new(Tx)
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < TxMaxRetries; i++ {
 | 
			
		||||
		err := func() error {
 | 
			
		||||
			tx.Tx, err = db.Begin()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return fmt.Errorf("error starting transaction: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !passwordEmpty {
 | 
			
		||||
				if err = tx.ChangePassword(id, oldPass, newPass); err != nil {
 | 
			
		||||
					if rollbackErr := tx.Rollback(); rollbackErr != nil {
 | 
			
		||||
						log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
 | 
			
		||||
					}
 | 
			
		||||
					return fmt.Errorf("error changing password: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err = tx.UpdateAttributes(
 | 
			
		||||
				&Attribute{Table: "users", ID: id, AttName: "username", Value: user},
 | 
			
		||||
				&Attribute{Table: "users", ID: id, AttName: "first_name", Value: first},
 | 
			
		||||
				&Attribute{Table: "users", ID: id, AttName: "last_name", Value: last},
 | 
			
		||||
			); err != nil {
 | 
			
		||||
				if rollbackErr := tx.Rollback(); rollbackErr != nil {
 | 
			
		||||
					log.Fatalf("error: transaction error: %v, rollback error: %v", err, rollbackErr)
 | 
			
		||||
				}
 | 
			
		||||
				return fmt.Errorf("error updating attributes in DB: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err = tx.Commit(); err != nil {
 | 
			
		||||
				return fmt.Errorf("error committing transaction: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return nil
 | 
			
		||||
		}()
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		log.Println(err)
 | 
			
		||||
		waitTime := time.Duration(math.Pow(2, float64(i))) * time.Second
 | 
			
		||||
		jitter := time.Duration(rand.IntN(1000)) * time.Millisecond
 | 
			
		||||
		time.Sleep(waitTime + jitter)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Errorf("error: %v unsuccessful retries for DB operation, aborting", TxMaxRetries)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -192,47 +192,17 @@ func UpdateUser(db *model.DB, s *control.CookieStore) http.HandlerFunc {
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		tx, err := db.StartTransaction()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			http.Error(w, err.Error(), http.StatusInternalServerError)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if len(newPass) > 0 || len(newPass2) > 0 {
 | 
			
		||||
			if newPass != newPass2 {
 | 
			
		||||
				tx.RollbackTransaction()
 | 
			
		||||
				userData.Msg = "Die Passwörter stimmen nicht überein."
 | 
			
		||||
				tmpl, err := template.ParseFiles("web/templates/edit-user.html")
 | 
			
		||||
				tmpl = template.Must(tmpl, err)
 | 
			
		||||
				tmpl.ExecuteTemplate(w, "page-content", userData)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err = tx.ChangePassword(userData.ID, oldPass, newPass); err != nil {
 | 
			
		||||
				log.Println(err)
 | 
			
		||||
				userData.Msg = "Das alte Passwort ist nicht korrekt."
 | 
			
		||||
				tmpl, err := template.ParseFiles("web/templates/edit-user.html")
 | 
			
		||||
				tmpl = template.Must(tmpl, err)
 | 
			
		||||
				tmpl.ExecuteTemplate(w, "page-content", userData)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err = tx.UpdateAttributes(
 | 
			
		||||
			&model.Attribute{Table: "users", ID: userData.ID, AttName: "username", Value: userData.UserName},
 | 
			
		||||
			&model.Attribute{Table: "users", ID: userData.ID, AttName: "first_name", Value: userData.FirstName},
 | 
			
		||||
			&model.Attribute{Table: "users", ID: userData.ID, AttName: "last_name", Value: userData.LastName},
 | 
			
		||||
		); err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			http.Error(w, err.Error(), http.StatusInternalServerError)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err = tx.Commit(); err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			http.Error(w, err.Error(), http.StatusInternalServerError)
 | 
			
		||||
			return
 | 
			
		||||
		if err = db.UpdateUserAttributes(
 | 
			
		||||
			userData.ID,
 | 
			
		||||
			userData.UserName,
 | 
			
		||||
			userData.FirstName,
 | 
			
		||||
			userData.LastName,
 | 
			
		||||
			oldPass,
 | 
			
		||||
			newPass,
 | 
			
		||||
			newPass2); err != nil {
 | 
			
		||||
			userData.Msg = "Aktualisierung der Benutzerdaten fehlgeschlagen."
 | 
			
		||||
			tmpl, err := template.ParseFiles("web/templates/edit-user.html")
 | 
			
		||||
			template.Must(tmpl, err).ExecuteTemplate(w, "page-content", userData)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		tmpl, err := template.ParseFiles("web/templates/hub.html")
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user