kevo/pkg/engine/transaction/registry.go
Jeremy Tregunna 7e226825df
All checks were successful
Go Tests / Run Tests (1.24.2) (push) Successful in 9m48s
fix: engine refactor bugfix fest, go fmt
2025-04-25 23:36:08 -06:00

297 lines
6.9 KiB
Go

package transaction
import (
"context"
"fmt"
"sync"
"time"
"github.com/KevoDB/kevo/pkg/engine/interfaces"
)
// Registry tracks the engine transactions opened through Begin, keyed by a
// registry-generated ID, and remembers which connection opened each one so
// that a connection's transactions can be rolled back together. NewRegistry
// also starts a background goroutine that periodically sweeps out stale
// transactions.
type Registry struct {
	// mu guards transactions, connectionTxs and nextID.
	mu sync.RWMutex

	// transactions maps a registry ID ("tx-N") to its live transaction.
	transactions map[string]interfaces.Transaction
	// connectionTxs maps a connection ID to the set of registry IDs it opened.
	connectionTxs map[string]map[string]struct{}
	// nextID is the counter behind the "tx-N" IDs; Begin increments it before
	// use, so the first ID handed out is "tx-1".
	nextID uint64

	// cleanupTicker paces the stale-transaction sweeps.
	cleanupTicker *time.Ticker
	// stopCleanup is closed to make the sweep goroutine exit.
	stopCleanup chan struct{}
}
// NewRegistry returns an empty Registry and launches its background sweep
// goroutine, which checks for stale transactions every five seconds until
// GracefulShutdown closes stopCleanup.
func NewRegistry() *Registry {
	reg := &Registry{
		stopCleanup:   make(chan struct{}),
		connectionTxs: make(map[string]map[string]struct{}),
		transactions:  make(map[string]interfaces.Transaction),
		cleanupTicker: time.NewTicker(5 * time.Second),
	}

	go reg.cleanupStaleTx()

	return reg
}
// cleanupStaleTx is the body of the background sweep goroutine. Every tick of
// cleanupTicker triggers cleanupStaleTransactions; once stopCleanup is closed
// the ticker is stopped and the goroutine returns.
func (r *Registry) cleanupStaleTx() {
	// The loop below can only be left through the stop case, so this is the
	// single place the ticker gets stopped.
	defer r.cleanupTicker.Stop()

	for {
		select {
		case <-r.stopCleanup:
			return
		case <-r.cleanupTicker.C:
			r.cleanupStaleTransactions()
		}
	}
}
// cleanupStaleTransactions rolls back and unregisters every transaction the
// registry considers stale.
//
// Staleness is measured from a *Transaction's startTime, i.e. age since
// creation, not time since last use. Transactions of any other concrete type
// expose no start time here, so they are treated as stale on every sweep.
// NOTE(review): that rule means a non-*Transaction is rolled back by the first
// sweep after it is registered; confirm this is intended.
//
// Stale entries are removed from the registry maps while r.mu is held, but
// their Rollback calls run only after the lock is released, so a slow or hung
// rollback cannot block Begin, Get or Remove. Rollback errors are ignored.
func (r *Registry) cleanupStaleTransactions() {
	const maxAge = 2 * time.Minute

	r.mu.Lock()
	now := time.Now()

	stale := make(map[string]interfaces.Transaction)
	for id, tx := range r.transactions {
		if ourTx, ok := tx.(*Transaction); ok && now.Sub(ourTx.startTime) < maxAge {
			continue // young enough to keep
		}
		stale[id] = tx
		// Deleting the current key during a map range is permitted in Go.
		delete(r.transactions, id)
	}

	if len(stale) > 0 {
		// Drop the stale IDs from per-connection tracking and forget any
		// connection that is left with no transactions.
		for connID, txs := range r.connectionTxs {
			for id := range txs {
				if _, isStale := stale[id]; isStale {
					delete(txs, id)
				}
			}
			if len(txs) == 0 {
				delete(r.connectionTxs, connID)
			}
		}
	}
	r.mu.Unlock()

	if len(stale) == 0 {
		return
	}

	fmt.Printf("Cleaning up %d stale transactions\n", len(stale))
	for id, tx := range stale {
		if tx != nil {
			_ = tx.Rollback() // best-effort cleanup; errors are deliberately ignored
		}
		fmt.Printf("Removed stale transaction: %s\n", id)
	}
}
// Begin asks eng for a new transaction and registers it, returning the
// registry ID ("tx-N") under which it can later be retrieved with Get.
//
// The transaction is attributed to the connection ID stored as a string under
// the "peer" key of ctx, or to "unknown" when ctx carries no such value.
//
// Waiting for the engine is bounded by 10 seconds, or by ctx if it is
// cancelled or expires sooner. On timeout an error wrapping the context error
// is returned. If the engine produces the transaction after Begin has stopped
// waiting, that transaction is rolled back instead of being leaked.
func (r *Registry) Begin(ctx context.Context, eng interfaces.Engine, readOnly bool) (string, error) {
	// NOTE(review): a plain string context key is flagged by staticcheck
	// (SA1029) and may collide with other packages. It is kept unchanged
	// because whoever stores the peer value must use the very same key.
	connectionID := "unknown"
	if p, ok := ctx.Value("peer").(string); ok {
		connectionID = p
	}

	timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	type txResult struct {
		tx  interfaces.Transaction
		err error
	}
	// Unbuffered on purpose: a send completes only while Begin is still
	// receiving, so exactly one side ends up owning the transaction. With a
	// buffered channel the worker could drop a transaction into the buffer
	// after Begin had already timed out, and nobody would ever roll it back.
	resultCh := make(chan txResult)

	// BeginTransaction takes no context, so it runs in its own goroutine to
	// make the wait for it cancellable.
	go func() {
		tx, err := eng.BeginTransaction(readOnly)
		select {
		case resultCh <- txResult{tx, err}:
			// Begin received the result and now owns tx.
		case <-timeoutCtx.Done():
			// Begin stopped waiting without receiving, so nobody else will
			// ever see tx; discard it.
			if tx != nil {
				_ = tx.Rollback() // best effort; there is no caller to report to
			}
		}
	}()

	var result txResult
	select {
	case result = <-resultCh:
	case <-timeoutCtx.Done():
		return "", fmt.Errorf("transaction creation timed out: %w", timeoutCtx.Err())
	}

	if result.err != nil {
		return "", fmt.Errorf("failed to begin transaction: %w", result.err)
	}
	if result.tx == nil {
		// Registering nil would make the later Rollback calls in the cleanup
		// paths panic.
		return "", fmt.Errorf("failed to begin transaction: engine returned a nil transaction")
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	r.nextID++
	txID := fmt.Sprintf("tx-%d", r.nextID)
	r.transactions[txID] = result.tx

	txs, ok := r.connectionTxs[connectionID]
	if !ok {
		txs = make(map[string]struct{})
		r.connectionTxs[connectionID] = txs
	}
	txs[txID] = struct{}{}

	fmt.Printf("Created transaction: %s (connection: %s)\n", txID, connectionID)
	return txID, nil
}
// Get looks up the transaction registered under txID. The boolean reports
// whether an entry exists; when it is false the returned transaction is nil.
func (r *Registry) Get(txID string) (interfaces.Transaction, bool) {
	r.mu.RLock()
	tx, found := r.transactions[txID]
	r.mu.RUnlock()

	return tx, found
}
// Remove unregisters txID from the registry and from its connection's
// tracking set, dropping the connection entry once it becomes empty. It does
// not call Commit or Rollback on the transaction. Unknown IDs are ignored.
func (r *Registry) Remove(txID string) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if _, ok := r.transactions[txID]; !ok {
		return
	}
	delete(r.transactions, txID)

	// Begin records each ID under a single connection, so stop at the first
	// connection that tracks it.
	for conn, set := range r.connectionTxs {
		if _, tracked := set[txID]; !tracked {
			continue
		}
		delete(set, txID)
		if len(set) == 0 {
			delete(r.connectionTxs, conn)
		}
		break
	}
}
// CleanupConnection rolls back and unregisters every transaction that Begin
// attributed to connectionID, then forgets the connection. Rollback errors are
// ignored, and unknown connection IDs are a no-op.
//
// The transactions are detached from the registry while r.mu is held, but
// rolled back only after the lock is released, so a slow or hung rollback
// does not stall unrelated registry operations.
func (r *Registry) CleanupConnection(connectionID string) {
	r.mu.Lock()
	txIDs, exists := r.connectionTxs[connectionID]
	if !exists {
		r.mu.Unlock()
		return
	}
	delete(r.connectionTxs, connectionID)

	fmt.Printf("Cleaning up %d transactions for disconnected connection %s\n",
		len(txIDs), connectionID)

	toRollback := make([]interfaces.Transaction, 0, len(txIDs))
	for txID := range txIDs {
		if tx, ok := r.transactions[txID]; ok {
			delete(r.transactions, txID)
			toRollback = append(toRollback, tx)
		}
	}
	r.mu.Unlock()

	for _, tx := range toRollback {
		if tx != nil {
			_ = tx.Rollback() // best-effort cleanup; errors are deliberately ignored
		}
	}
}
// GracefulShutdown stops the background sweep goroutine, rolls back every
// registered transaction and clears all registry state.
//
// Each Rollback runs in its own goroutine and is waited on for at most one
// second, or until ctx is done if that is sooner. A rollback that misses the
// deadline is reported as an error, but its goroutine is left to finish on its
// own. Transactions are removed from the registry whether or not their
// rollback succeeded. If any rollback fails, the error for the last failure
// encountered is returned; map iteration order makes "last" arbitrary.
// Calling GracefulShutdown more than once does not panic.
func (r *Registry) GracefulShutdown(ctx context.Context) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	// close panics on an already-closed channel. Checking and closing while
	// holding r.mu turns a repeated shutdown into a no-op for this step.
	select {
	case <-r.stopCleanup:
	default:
		close(r.stopCleanup)
	}
	r.cleanupTicker.Stop()

	var lastErr error
	// Deleting the current key while ranging over a map is permitted in Go,
	// so no snapshot of the IDs is needed.
	for id, tx := range r.transactions {
		delete(r.transactions, id)
		if tx == nil {
			continue // nothing to roll back; calling Rollback would panic
		}

		rollbackCtx, cancel := context.WithTimeout(ctx, 1*time.Second)

		// Buffered so the rollback goroutine can always deliver its result and
		// exit, even after we have stopped waiting for it.
		doneCh := make(chan error, 1)
		go func(t interfaces.Transaction) {
			doneCh <- t.Rollback()
		}(tx)

		var err error
		select {
		case err = <-doneCh:
		case <-rollbackCtx.Done():
			err = fmt.Errorf("rollback timed out: %w", rollbackCtx.Err())
		}
		cancel()

		if err != nil {
			lastErr = fmt.Errorf("failed to rollback transaction %s: %w", id, err)
		}
	}

	r.connectionTxs = make(map[string]map[string]struct{})
	return lastErr
}