chore: formatting
Some checks failed
Go Tests / Run Tests (1.24.2) (push) Failing after 15m7s

This commit is contained in:
Jeremy Tregunna 2025-05-02 15:41:46 -06:00
parent 7e744fe85b
commit 9a98349115
Signed by: jer
GPG Key ID: 1278B36BA6F5D5E4
21 changed files with 363 additions and 362 deletions

View File

@ -205,7 +205,7 @@ func parseFlags() Config {
} }
fmt.Printf("DEBUG: Config created: ReplicationEnabled=%v, ReplicationMode=%s\n", fmt.Printf("DEBUG: Config created: ReplicationEnabled=%v, ReplicationMode=%s\n",
config.ReplicationEnabled, config.ReplicationMode) config.ReplicationEnabled, config.ReplicationMode)
return config return config
} }
@ -219,7 +219,7 @@ func runServer(eng *engine.Engine, config Config) {
// Create and start the server // Create and start the server
fmt.Printf("DEBUG: Before server creation: ReplicationEnabled=%v, ReplicationMode=%s\n", fmt.Printf("DEBUG: Before server creation: ReplicationEnabled=%v, ReplicationMode=%s\n",
config.ReplicationEnabled, config.ReplicationMode) config.ReplicationEnabled, config.ReplicationMode)
server := NewServer(eng, config) server := NewServer(eng, config)
// Start the server (non-blocking) // Start the server (non-blocking)

View File

@ -131,7 +131,7 @@ func (s *Server) Start() error {
fmt.Printf("DEBUG: Using replication manager for role %s\n", s.config.ReplicationMode) fmt.Printf("DEBUG: Using replication manager for role %s\n", s.config.ReplicationMode)
repManager = s.replicationManager repManager = s.replicationManager
} else { } else {
fmt.Printf("DEBUG: No replication manager available. ReplicationEnabled: %v, Manager nil: %v\n", fmt.Printf("DEBUG: No replication manager available. ReplicationEnabled: %v, Manager nil: %v\n",
s.config.ReplicationEnabled, s.replicationManager == nil) s.config.ReplicationEnabled, s.replicationManager == nil)
} }

View File

@ -30,7 +30,7 @@ func TestTransactionManager(t *testing.T) {
// Get the transaction manager // Get the transaction manager
txManager := eng.GetTransactionManager() txManager := eng.GetTransactionManager()
// Test read-write transaction // Test read-write transaction
rwTx, err := txManager.BeginTransaction(false) rwTx, err := txManager.BeginTransaction(false)
if err != nil { if err != nil {
@ -39,12 +39,12 @@ func TestTransactionManager(t *testing.T) {
if rwTx.IsReadOnly() { if rwTx.IsReadOnly() {
t.Fatal("Expected non-read-only transaction") t.Fatal("Expected non-read-only transaction")
} }
// Test committing the transaction // Test committing the transaction
if err := rwTx.Commit(); err != nil { if err := rwTx.Commit(); err != nil {
t.Fatalf("Failed to commit transaction: %v", err) t.Fatalf("Failed to commit transaction: %v", err)
} }
// Test read-only transaction // Test read-only transaction
roTx, err := txManager.BeginTransaction(true) roTx, err := txManager.BeginTransaction(true)
if err != nil { if err != nil {
@ -53,7 +53,7 @@ func TestTransactionManager(t *testing.T) {
if !roTx.IsReadOnly() { if !roTx.IsReadOnly() {
t.Fatal("Expected read-only transaction") t.Fatal("Expected read-only transaction")
} }
// Test rollback // Test rollback
if err := roTx.Rollback(); err != nil { if err := roTx.Rollback(); err != nil {
t.Fatalf("Failed to rollback transaction: %v", err) t.Fatalf("Failed to rollback transaction: %v", err)

View File

@ -3,7 +3,7 @@ package transaction
import ( import (
"context" "context"
"sync" "sync"
"github.com/KevoDB/kevo/pkg/common/iterator" "github.com/KevoDB/kevo/pkg/common/iterator"
"github.com/KevoDB/kevo/pkg/engine/interfaces" "github.com/KevoDB/kevo/pkg/engine/interfaces"
"github.com/KevoDB/kevo/pkg/stats" "github.com/KevoDB/kevo/pkg/stats"
@ -139,7 +139,7 @@ func (w *registryWrapper) GracefulShutdown(ctx context.Context) error {
func NewManager(storage interfaces.StorageManager, statsCollector stats.Collector) interfaces.TransactionManager { func NewManager(storage interfaces.StorageManager, statsCollector stats.Collector) interfaces.TransactionManager {
// Create a storage adapter that works with our new transaction implementation // Create a storage adapter that works with our new transaction implementation
adapter := &storageAdapter{storage: storage} adapter := &storageAdapter{storage: storage}
// Create the new transaction manager and wrap it // Create the new transaction manager and wrap it
return &managerWrapper{ return &managerWrapper{
inner: tx.NewManager(adapter, statsCollector), inner: tx.NewManager(adapter, statsCollector),
@ -152,4 +152,4 @@ func NewRegistry() interfaces.TxRegistry {
return &registryWrapper{ return &registryWrapper{
inner: tx.NewRegistry(), inner: tx.NewRegistry(),
} }
} }

View File

@ -125,98 +125,99 @@ func (m *MockStatsCollector) StartRecovery() time.Time {
return time.Now() return time.Now()
} }
func (m *MockStatsCollector) FinishRecovery(startTime time.Time, filesRecovered, entriesRecovered, corruptedEntries uint64) {} func (m *MockStatsCollector) FinishRecovery(startTime time.Time, filesRecovered, entriesRecovered, corruptedEntries uint64) {
}
func TestForwardingLayer(t *testing.T) { func TestForwardingLayer(t *testing.T) {
// Create mocks // Create mocks
storage := &MockStorage{} storage := &MockStorage{}
statsCollector := &MockStatsCollector{} statsCollector := &MockStatsCollector{}
// Create the manager through the forwarding layer // Create the manager through the forwarding layer
manager := NewManager(storage, statsCollector) manager := NewManager(storage, statsCollector)
// Verify the manager was created // Verify the manager was created
if manager == nil { if manager == nil {
t.Fatal("Expected manager to be created, got nil") t.Fatal("Expected manager to be created, got nil")
} }
// Get the RWLock // Get the RWLock
rwLock := manager.GetRWLock() rwLock := manager.GetRWLock()
if rwLock == nil { if rwLock == nil {
t.Fatal("Expected non-nil RWLock") t.Fatal("Expected non-nil RWLock")
} }
// Test transaction creation // Test transaction creation
tx, err := manager.BeginTransaction(true) tx, err := manager.BeginTransaction(true)
if err != nil { if err != nil {
t.Fatalf("Unexpected error beginning transaction: %v", err) t.Fatalf("Unexpected error beginning transaction: %v", err)
} }
// Verify it's a read-only transaction // Verify it's a read-only transaction
if !tx.IsReadOnly() { if !tx.IsReadOnly() {
t.Error("Expected read-only transaction") t.Error("Expected read-only transaction")
} }
// Test some operations // Test some operations
_, err = tx.Get([]byte("key")) _, err = tx.Get([]byte("key"))
if err != nil { if err != nil {
t.Errorf("Unexpected error in Get: %v", err) t.Errorf("Unexpected error in Get: %v", err)
} }
// Commit the transaction // Commit the transaction
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
t.Errorf("Unexpected error committing transaction: %v", err) t.Errorf("Unexpected error committing transaction: %v", err)
} }
// Create a read-write transaction // Create a read-write transaction
tx, err = manager.BeginTransaction(false) tx, err = manager.BeginTransaction(false)
if err != nil { if err != nil {
t.Fatalf("Unexpected error beginning transaction: %v", err) t.Fatalf("Unexpected error beginning transaction: %v", err)
} }
// Verify it's a read-write transaction // Verify it's a read-write transaction
if tx.IsReadOnly() { if tx.IsReadOnly() {
t.Error("Expected read-write transaction") t.Error("Expected read-write transaction")
} }
// Test put operation // Test put operation
err = tx.Put([]byte("key"), []byte("value")) err = tx.Put([]byte("key"), []byte("value"))
if err != nil { if err != nil {
t.Errorf("Unexpected error in Put: %v", err) t.Errorf("Unexpected error in Put: %v", err)
} }
// Test delete operation // Test delete operation
err = tx.Delete([]byte("key")) err = tx.Delete([]byte("key"))
if err != nil { if err != nil {
t.Errorf("Unexpected error in Delete: %v", err) t.Errorf("Unexpected error in Delete: %v", err)
} }
// Test iterator // Test iterator
it := tx.NewIterator() it := tx.NewIterator()
if it == nil { if it == nil {
t.Error("Expected non-nil iterator") t.Error("Expected non-nil iterator")
} }
// Test range iterator // Test range iterator
rangeIt := tx.NewRangeIterator([]byte("a"), []byte("z")) rangeIt := tx.NewRangeIterator([]byte("a"), []byte("z"))
if rangeIt == nil { if rangeIt == nil {
t.Error("Expected non-nil range iterator") t.Error("Expected non-nil range iterator")
} }
// Rollback the transaction // Rollback the transaction
err = tx.Rollback() err = tx.Rollback()
if err != nil { if err != nil {
t.Errorf("Unexpected error rolling back transaction: %v", err) t.Errorf("Unexpected error rolling back transaction: %v", err)
} }
// Verify IncrementTxCompleted and IncrementTxAborted are working // Verify IncrementTxCompleted and IncrementTxAborted are working
manager.IncrementTxCompleted() manager.IncrementTxCompleted()
manager.IncrementTxAborted() manager.IncrementTxAborted()
// Test the registry creation // Test the registry creation
registry := NewRegistry() registry := NewRegistry()
if registry == nil { if registry == nil {
t.Fatal("Expected registry to be created, got nil") t.Fatal("Expected registry to be created, got nil")
} }
} }

View File

@ -6,8 +6,8 @@ import (
const ( const (
ReplicationModeStandalone = "standalone" ReplicationModeStandalone = "standalone"
ReplicationModePrimary = "primary" ReplicationModePrimary = "primary"
ReplicationModeReplica = "replica" ReplicationModeReplica = "replica"
) )
// ReplicationNodeInfo contains information about a node in the replication topology // ReplicationNodeInfo contains information about a node in the replication topology

View File

@ -199,7 +199,7 @@ func (it *BufferIterator) Next() bool {
it.SeekToFirst() it.SeekToFirst()
return it.Valid() return it.Valid()
} }
if it.position >= len(it.operations)-1 { if it.position >= len(it.operations)-1 {
it.position = -1 it.position = -1
return false return false
@ -236,4 +236,4 @@ func (it *BufferIterator) IsTombstone() bool {
return false return false
} }
return it.operations[it.position].IsDelete return it.operations[it.position].IsDelete
} }

View File

@ -7,21 +7,21 @@ import (
func TestBufferBasicOperations(t *testing.T) { func TestBufferBasicOperations(t *testing.T) {
b := NewBuffer() b := NewBuffer()
// Test initial state // Test initial state
if b.Size() != 0 { if b.Size() != 0 {
t.Errorf("Expected empty buffer, got size %d", b.Size()) t.Errorf("Expected empty buffer, got size %d", b.Size())
} }
// Test Put operation // Test Put operation
key1 := []byte("key1") key1 := []byte("key1")
value1 := []byte("value1") value1 := []byte("value1")
b.Put(key1, value1) b.Put(key1, value1)
if b.Size() != 1 { if b.Size() != 1 {
t.Errorf("Expected buffer size 1, got %d", b.Size()) t.Errorf("Expected buffer size 1, got %d", b.Size())
} }
// Test Get operation // Test Get operation
val, found := b.Get(key1) val, found := b.Get(key1)
if !found { if !found {
@ -30,15 +30,15 @@ func TestBufferBasicOperations(t *testing.T) {
if !bytes.Equal(val, value1) { if !bytes.Equal(val, value1) {
t.Errorf("Expected value %s, got %s", value1, val) t.Errorf("Expected value %s, got %s", value1, val)
} }
// Test overwriting a key // Test overwriting a key
newValue1 := []byte("new_value1") newValue1 := []byte("new_value1")
b.Put(key1, newValue1) b.Put(key1, newValue1)
if b.Size() != 1 { if b.Size() != 1 {
t.Errorf("Expected buffer size to remain 1 after overwrite, got %d", b.Size()) t.Errorf("Expected buffer size to remain 1 after overwrite, got %d", b.Size())
} }
val, found = b.Get(key1) val, found = b.Get(key1)
if !found { if !found {
t.Errorf("Expected to find key %s after overwrite, but it was not found", key1) t.Errorf("Expected to find key %s after overwrite, but it was not found", key1)
@ -46,14 +46,14 @@ func TestBufferBasicOperations(t *testing.T) {
if !bytes.Equal(val, newValue1) { if !bytes.Equal(val, newValue1) {
t.Errorf("Expected updated value %s, got %s", newValue1, val) t.Errorf("Expected updated value %s, got %s", newValue1, val)
} }
// Test Delete operation // Test Delete operation
b.Delete(key1) b.Delete(key1)
if b.Size() != 1 { if b.Size() != 1 {
t.Errorf("Expected buffer size to remain 1 after delete, got %d", b.Size()) t.Errorf("Expected buffer size to remain 1 after delete, got %d", b.Size())
} }
val, found = b.Get(key1) val, found = b.Get(key1)
if !found { if !found {
t.Errorf("Expected to find key %s after delete op, but it was not found", key1) t.Errorf("Expected to find key %s after delete op, but it was not found", key1)
@ -61,10 +61,10 @@ func TestBufferBasicOperations(t *testing.T) {
if val != nil { if val != nil {
t.Errorf("Expected nil value after delete, got %s", val) t.Errorf("Expected nil value after delete, got %s", val)
} }
// Test Clear operation // Test Clear operation
b.Clear() b.Clear()
if b.Size() != 0 { if b.Size() != 0 {
t.Errorf("Expected empty buffer after clear, got size %d", b.Size()) t.Errorf("Expected empty buffer after clear, got size %d", b.Size())
} }
@ -72,7 +72,7 @@ func TestBufferBasicOperations(t *testing.T) {
func TestBufferOperationsMethod(t *testing.T) { func TestBufferOperationsMethod(t *testing.T) {
b := NewBuffer() b := NewBuffer()
// Add multiple operations // Add multiple operations
keys := [][]byte{ keys := [][]byte{
[]byte("c"), []byte("c"),
@ -84,18 +84,18 @@ func TestBufferOperationsMethod(t *testing.T) {
[]byte("value_a"), []byte("value_a"),
[]byte("value_b"), []byte("value_b"),
} }
b.Put(keys[0], values[0]) b.Put(keys[0], values[0])
b.Put(keys[1], values[1]) b.Put(keys[1], values[1])
b.Put(keys[2], values[2]) b.Put(keys[2], values[2])
// Test Operations() returns operations sorted by key // Test Operations() returns operations sorted by key
ops := b.Operations() ops := b.Operations()
if len(ops) != 3 { if len(ops) != 3 {
t.Errorf("Expected 3 operations, got %d", len(ops)) t.Errorf("Expected 3 operations, got %d", len(ops))
} }
// Check the order (should be sorted by key: a, b, c) // Check the order (should be sorted by key: a, b, c)
expected := [][]byte{keys[1], keys[2], keys[0]} expected := [][]byte{keys[1], keys[2], keys[0]}
for i, op := range ops { for i, op := range ops {
@ -103,23 +103,23 @@ func TestBufferOperationsMethod(t *testing.T) {
t.Errorf("Expected key %s at position %d, got %s", expected[i], i, op.Key) t.Errorf("Expected key %s at position %d, got %s", expected[i], i, op.Key)
} }
} }
// Test with delete operations // Test with delete operations
b.Clear() b.Clear()
b.Put(keys[0], values[0]) b.Put(keys[0], values[0])
b.Delete(keys[1]) b.Delete(keys[1])
ops = b.Operations() ops = b.Operations()
if len(ops) != 2 { if len(ops) != 2 {
t.Errorf("Expected 2 operations, got %d", len(ops)) t.Errorf("Expected 2 operations, got %d", len(ops))
} }
// The first should be a delete for 'a', the second a put for 'c' // The first should be a delete for 'a', the second a put for 'c'
if !bytes.Equal(ops[0].Key, keys[1]) || !ops[0].IsDelete { if !bytes.Equal(ops[0].Key, keys[1]) || !ops[0].IsDelete {
t.Errorf("Expected delete operation for key %s, got %v", keys[1], ops[0]) t.Errorf("Expected delete operation for key %s, got %v", keys[1], ops[0])
} }
if !bytes.Equal(ops[1].Key, keys[0]) || ops[1].IsDelete { if !bytes.Equal(ops[1].Key, keys[0]) || ops[1].IsDelete {
t.Errorf("Expected put operation for key %s, got %v", keys[0], ops[1]) t.Errorf("Expected put operation for key %s, got %v", keys[0], ops[1])
} }
@ -127,7 +127,7 @@ func TestBufferOperationsMethod(t *testing.T) {
func TestBufferIterator(t *testing.T) { func TestBufferIterator(t *testing.T) {
b := NewBuffer() b := NewBuffer()
// Add multiple operations in non-sorted order // Add multiple operations in non-sorted order
keys := [][]byte{ keys := [][]byte{
[]byte("c"), []byte("c"),
@ -139,147 +139,147 @@ func TestBufferIterator(t *testing.T) {
[]byte("value_a"), []byte("value_a"),
[]byte("value_b"), []byte("value_b"),
} }
for i := range keys { for i := range keys {
b.Put(keys[i], values[i]) b.Put(keys[i], values[i])
} }
// Test iterator // Test iterator
it := b.NewIterator() it := b.NewIterator()
// Test Seek behavior // Test Seek behavior
if !it.Seek([]byte("b")) { if !it.Seek([]byte("b")) {
t.Error("Expected Seek('b') to return true") t.Error("Expected Seek('b') to return true")
} }
if !bytes.Equal(it.Key(), []byte("b")) { if !bytes.Equal(it.Key(), []byte("b")) {
t.Errorf("Expected key 'b', got %s", it.Key()) t.Errorf("Expected key 'b', got %s", it.Key())
} }
if !bytes.Equal(it.Value(), []byte("value_b")) { if !bytes.Equal(it.Value(), []byte("value_b")) {
t.Errorf("Expected value 'value_b', got %s", it.Value()) t.Errorf("Expected value 'value_b', got %s", it.Value())
} }
// Test seeking to a key that should exist // Test seeking to a key that should exist
if !it.Seek([]byte("a")) { if !it.Seek([]byte("a")) {
t.Error("Expected Seek('a') to return true") t.Error("Expected Seek('a') to return true")
} }
// Test seeking to a key that doesn't exist but is within range // Test seeking to a key that doesn't exist but is within range
if !it.Seek([]byte("bb")) { if !it.Seek([]byte("bb")) {
t.Error("Expected Seek('bb') to return true") t.Error("Expected Seek('bb') to return true")
} }
if !bytes.Equal(it.Key(), []byte("c")) { if !bytes.Equal(it.Key(), []byte("c")) {
t.Errorf("Expected key 'c' (next key after 'bb'), got %s", it.Key()) t.Errorf("Expected key 'c' (next key after 'bb'), got %s", it.Key())
} }
// Test seeking past the end // Test seeking past the end
if it.Seek([]byte("d")) { if it.Seek([]byte("d")) {
t.Error("Expected Seek('d') to return false") t.Error("Expected Seek('d') to return false")
} }
if it.Valid() { if it.Valid() {
t.Error("Expected iterator to be invalid after seeking past end") t.Error("Expected iterator to be invalid after seeking past end")
} }
// Test SeekToFirst // Test SeekToFirst
it.SeekToFirst() it.SeekToFirst()
if !it.Valid() { if !it.Valid() {
t.Error("Expected iterator to be valid after SeekToFirst") t.Error("Expected iterator to be valid after SeekToFirst")
} }
if !bytes.Equal(it.Key(), []byte("a")) { if !bytes.Equal(it.Key(), []byte("a")) {
t.Errorf("Expected first key to be 'a', got %s", it.Key()) t.Errorf("Expected first key to be 'a', got %s", it.Key())
} }
// Test Next // Test Next
if !it.Next() { if !it.Next() {
t.Error("Expected Next() to return true") t.Error("Expected Next() to return true")
} }
if !bytes.Equal(it.Key(), []byte("b")) { if !bytes.Equal(it.Key(), []byte("b")) {
t.Errorf("Expected second key to be 'b', got %s", it.Key()) t.Errorf("Expected second key to be 'b', got %s", it.Key())
} }
if !it.Next() { if !it.Next() {
t.Error("Expected Next() to return true for the third key") t.Error("Expected Next() to return true for the third key")
} }
if !bytes.Equal(it.Key(), []byte("c")) { if !bytes.Equal(it.Key(), []byte("c")) {
t.Errorf("Expected third key to be 'c', got %s", it.Key()) t.Errorf("Expected third key to be 'c', got %s", it.Key())
} }
// Should be at the end now // Should be at the end now
if it.Next() { if it.Next() {
t.Error("Expected Next() to return false after last key") t.Error("Expected Next() to return false after last key")
} }
if it.Valid() { if it.Valid() {
t.Error("Expected iterator to be invalid after iterating past end") t.Error("Expected iterator to be invalid after iterating past end")
} }
// Test SeekToLast // Test SeekToLast
it.SeekToLast() it.SeekToLast()
if !it.Valid() { if !it.Valid() {
t.Error("Expected iterator to be valid after SeekToLast") t.Error("Expected iterator to be valid after SeekToLast")
} }
if !bytes.Equal(it.Key(), []byte("c")) { if !bytes.Equal(it.Key(), []byte("c")) {
t.Errorf("Expected last key to be 'c', got %s", it.Key()) t.Errorf("Expected last key to be 'c', got %s", it.Key())
} }
// Test with delete operations // Test with delete operations
b.Clear() b.Clear()
b.Put([]byte("key1"), []byte("value1")) b.Put([]byte("key1"), []byte("value1"))
b.Delete([]byte("key2")) b.Delete([]byte("key2"))
it = b.NewIterator() it = b.NewIterator()
it.SeekToFirst() it.SeekToFirst()
// First key should be key1 // First key should be key1
if !bytes.Equal(it.Key(), []byte("key1")) { if !bytes.Equal(it.Key(), []byte("key1")) {
t.Errorf("Expected first key to be 'key1', got %s", it.Key()) t.Errorf("Expected first key to be 'key1', got %s", it.Key())
} }
if it.IsTombstone() { if it.IsTombstone() {
t.Error("Expected key1 not to be a tombstone") t.Error("Expected key1 not to be a tombstone")
} }
// Next key should be key2 // Next key should be key2
it.Next() it.Next()
if !bytes.Equal(it.Key(), []byte("key2")) { if !bytes.Equal(it.Key(), []byte("key2")) {
t.Errorf("Expected second key to be 'key2', got %s", it.Key()) t.Errorf("Expected second key to be 'key2', got %s", it.Key())
} }
if !it.IsTombstone() { if !it.IsTombstone() {
t.Error("Expected key2 to be a tombstone") t.Error("Expected key2 to be a tombstone")
} }
// Test empty iterator // Test empty iterator
b.Clear() b.Clear()
it = b.NewIterator() it = b.NewIterator()
if it.Valid() { if it.Valid() {
t.Error("Expected iterator to be invalid for empty buffer") t.Error("Expected iterator to be invalid for empty buffer")
} }
it.SeekToFirst() it.SeekToFirst()
if it.Valid() { if it.Valid() {
t.Error("Expected iterator to be invalid after SeekToFirst on empty buffer") t.Error("Expected iterator to be invalid after SeekToFirst on empty buffer")
} }
it.SeekToLast() it.SeekToLast()
if it.Valid() { if it.Valid() {
t.Error("Expected iterator to be invalid after SeekToLast on empty buffer") t.Error("Expected iterator to be invalid after SeekToLast on empty buffer")
} }
if it.Seek([]byte("any")) { if it.Seek([]byte("any")) {
t.Error("Expected Seek to return false on empty buffer") t.Error("Expected Seek to return false on empty buffer")
} }
} }

View File

@ -6,13 +6,13 @@ import "errors"
var ( var (
// ErrReadOnlyTransaction is returned when a write operation is attempted on a read-only transaction // ErrReadOnlyTransaction is returned when a write operation is attempted on a read-only transaction
ErrReadOnlyTransaction = errors.New("cannot write to a read-only transaction") ErrReadOnlyTransaction = errors.New("cannot write to a read-only transaction")
// ErrTransactionClosed is returned when an operation is attempted on a closed transaction // ErrTransactionClosed is returned when an operation is attempted on a closed transaction
ErrTransactionClosed = errors.New("transaction already committed or rolled back") ErrTransactionClosed = errors.New("transaction already committed or rolled back")
// ErrKeyNotFound is returned when a key doesn't exist // ErrKeyNotFound is returned when a key doesn't exist
ErrKeyNotFound = errors.New("key not found") ErrKeyNotFound = errors.New("key not found")
// ErrInvalidEngine is returned when an incompatible engine type is provided // ErrInvalidEngine is returned when an incompatible engine type is provided
ErrInvalidEngine = errors.New("invalid engine type") ErrInvalidEngine = errors.New("invalid engine type")
) )

View File

@ -68,4 +68,4 @@ type Registry interface {
// GracefulShutdown performs cleanup on shutdown // GracefulShutdown performs cleanup on shutdown
GracefulShutdown(ctx context.Context) error GracefulShutdown(ctx context.Context) error
} }

View File

@ -48,16 +48,16 @@ func (m *Manager) BeginTransaction(readOnly bool) (Transaction, error) {
// Create a new transaction // Create a new transaction
tx := &TransactionImpl{ tx := &TransactionImpl{
storage: m.storage, storage: m.storage,
mode: mode, mode: mode,
buffer: NewBuffer(), buffer: NewBuffer(),
rwLock: &m.txLock, rwLock: &m.txLock,
stats: m, stats: m,
} }
// Set transaction as active // Set transaction as active
tx.active.Store(true) tx.active.Store(true)
// Acquire appropriate lock // Acquire appropriate lock
if mode == ReadOnly { if mode == ReadOnly {
m.txLock.RLock() m.txLock.RLock()
@ -108,4 +108,4 @@ func (m *Manager) GetTransactionStats() map[string]interface{} {
stats["tx_active"] = active stats["tx_active"] = active
return stats return stats
} }

View File

@ -9,10 +9,10 @@ import (
func TestManagerBasics(t *testing.T) { func TestManagerBasics(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
// Create a transaction manager // Create a transaction manager
manager := NewManager(storage, statsCollector) manager := NewManager(storage, statsCollector)
// Test starting a read-only transaction // Test starting a read-only transaction
tx1, err := manager.BeginTransaction(true) tx1, err := manager.BeginTransaction(true)
if err != nil { if err != nil {
@ -21,14 +21,14 @@ func TestManagerBasics(t *testing.T) {
if !tx1.IsReadOnly() { if !tx1.IsReadOnly() {
t.Error("Transaction should be read-only") t.Error("Transaction should be read-only")
} }
// Commit the read-only transaction before starting a read-write one // Commit the read-only transaction before starting a read-write one
// to avoid deadlock (since our tests run in a single thread) // to avoid deadlock (since our tests run in a single thread)
err = tx1.Commit() err = tx1.Commit()
if err != nil { if err != nil {
t.Errorf("Unexpected error committing read-only transaction: %v", err) t.Errorf("Unexpected error committing read-only transaction: %v", err)
} }
// Test starting a read-write transaction // Test starting a read-write transaction
tx2, err := manager.BeginTransaction(false) tx2, err := manager.BeginTransaction(false)
if err != nil { if err != nil {
@ -37,28 +37,28 @@ func TestManagerBasics(t *testing.T) {
if tx2.IsReadOnly() { if tx2.IsReadOnly() {
t.Error("Transaction should be read-write") t.Error("Transaction should be read-write")
} }
// Commit the read-write transaction // Commit the read-write transaction
err = tx2.Commit() err = tx2.Commit()
if err != nil { if err != nil {
t.Errorf("Unexpected error committing read-write transaction: %v", err) t.Errorf("Unexpected error committing read-write transaction: %v", err)
} }
// Verify stats tracking // Verify stats tracking
stats := manager.GetTransactionStats() stats := manager.GetTransactionStats()
if stats["tx_started"] != uint64(2) { if stats["tx_started"] != uint64(2) {
t.Errorf("Expected 2 transactions started, got %v", stats["tx_started"]) t.Errorf("Expected 2 transactions started, got %v", stats["tx_started"])
} }
if stats["tx_completed"] != uint64(2) { if stats["tx_completed"] != uint64(2) {
t.Errorf("Expected 2 transactions completed, got %v", stats["tx_completed"]) t.Errorf("Expected 2 transactions completed, got %v", stats["tx_completed"])
} }
if stats["tx_aborted"] != uint64(0) { if stats["tx_aborted"] != uint64(0) {
t.Errorf("Expected 0 transactions aborted, got %v", stats["tx_aborted"]) t.Errorf("Expected 0 transactions aborted, got %v", stats["tx_aborted"])
} }
if stats["tx_active"] != uint64(0) { if stats["tx_active"] != uint64(0) {
t.Errorf("Expected 0 active transactions, got %v", stats["tx_active"]) t.Errorf("Expected 0 active transactions, got %v", stats["tx_active"])
} }
@ -67,36 +67,36 @@ func TestManagerBasics(t *testing.T) {
func TestManagerRollback(t *testing.T) { func TestManagerRollback(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
// Create a transaction manager // Create a transaction manager
manager := NewManager(storage, statsCollector) manager := NewManager(storage, statsCollector)
// Start a transaction and roll it back // Start a transaction and roll it back
tx, err := manager.BeginTransaction(false) tx, err := manager.BeginTransaction(false)
if err != nil { if err != nil {
t.Errorf("Unexpected error beginning transaction: %v", err) t.Errorf("Unexpected error beginning transaction: %v", err)
} }
err = tx.Rollback() err = tx.Rollback()
if err != nil { if err != nil {
t.Errorf("Unexpected error rolling back transaction: %v", err) t.Errorf("Unexpected error rolling back transaction: %v", err)
} }
// Verify stats tracking // Verify stats tracking
stats := manager.GetTransactionStats() stats := manager.GetTransactionStats()
if stats["tx_started"] != uint64(1) { if stats["tx_started"] != uint64(1) {
t.Errorf("Expected 1 transaction started, got %v", stats["tx_started"]) t.Errorf("Expected 1 transaction started, got %v", stats["tx_started"])
} }
if stats["tx_completed"] != uint64(0) { if stats["tx_completed"] != uint64(0) {
t.Errorf("Expected 0 transactions completed, got %v", stats["tx_completed"]) t.Errorf("Expected 0 transactions completed, got %v", stats["tx_completed"])
} }
if stats["tx_aborted"] != uint64(1) { if stats["tx_aborted"] != uint64(1) {
t.Errorf("Expected 1 transaction aborted, got %v", stats["tx_aborted"]) t.Errorf("Expected 1 transaction aborted, got %v", stats["tx_aborted"])
} }
if stats["tx_active"] != uint64(0) { if stats["tx_active"] != uint64(0) {
t.Errorf("Expected 0 active transactions, got %v", stats["tx_active"]) t.Errorf("Expected 0 active transactions, got %v", stats["tx_active"])
} }
@ -105,69 +105,69 @@ func TestManagerRollback(t *testing.T) {
func TestConcurrentTransactions(t *testing.T) { func TestConcurrentTransactions(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
// Create a transaction manager // Create a transaction manager
manager := NewManager(storage, statsCollector) manager := NewManager(storage, statsCollector)
// Initialize some data // Initialize some data
storage.Put([]byte("counter"), []byte{0}) storage.Put([]byte("counter"), []byte{0})
// Rather than using concurrency which can cause flaky tests, // Rather than using concurrency which can cause flaky tests,
// we'll execute transactions sequentially but simulate the same behavior // we'll execute transactions sequentially but simulate the same behavior
numTransactions := 10 numTransactions := 10
for i := 0; i < numTransactions; i++ { for i := 0; i < numTransactions; i++ {
// Start a read-write transaction // Start a read-write transaction
tx, err := manager.BeginTransaction(false) tx, err := manager.BeginTransaction(false)
if err != nil { if err != nil {
t.Fatalf("Failed to begin transaction %d: %v", i, err) t.Fatalf("Failed to begin transaction %d: %v", i, err)
} }
// Read counter value // Read counter value
counterValue, err := tx.Get([]byte("counter")) counterValue, err := tx.Get([]byte("counter"))
if err != nil { if err != nil {
t.Fatalf("Failed to get counter in transaction %d: %v", i, err) t.Fatalf("Failed to get counter in transaction %d: %v", i, err)
} }
// Increment counter value // Increment counter value
newValue := []byte{counterValue[0] + 1} newValue := []byte{counterValue[0] + 1}
// Write new counter value // Write new counter value
err = tx.Put([]byte("counter"), newValue) err = tx.Put([]byte("counter"), newValue)
if err != nil { if err != nil {
t.Fatalf("Failed to update counter in transaction %d: %v", i, err) t.Fatalf("Failed to update counter in transaction %d: %v", i, err)
} }
// Commit transaction // Commit transaction
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
t.Fatalf("Failed to commit transaction %d: %v", i, err) t.Fatalf("Failed to commit transaction %d: %v", i, err)
} }
} }
// Verify final counter value // Verify final counter value
finalValue, err := storage.Get([]byte("counter")) finalValue, err := storage.Get([]byte("counter"))
if err != nil { if err != nil {
t.Errorf("Error getting final counter value: %v", err) t.Errorf("Error getting final counter value: %v", err)
} }
// Counter should have been incremented numTransactions times // Counter should have been incremented numTransactions times
expectedValue := byte(numTransactions) expectedValue := byte(numTransactions)
if finalValue[0] != expectedValue { if finalValue[0] != expectedValue {
t.Errorf("Expected counter value %d, got %d", expectedValue, finalValue[0]) t.Errorf("Expected counter value %d, got %d", expectedValue, finalValue[0])
} }
// Verify that all transactions completed // Verify that all transactions completed
stats := manager.GetTransactionStats() stats := manager.GetTransactionStats()
if stats["tx_started"] != uint64(numTransactions) { if stats["tx_started"] != uint64(numTransactions) {
t.Errorf("Expected %d transactions started, got %v", numTransactions, stats["tx_started"]) t.Errorf("Expected %d transactions started, got %v", numTransactions, stats["tx_started"])
} }
if stats["tx_completed"] != uint64(numTransactions) { if stats["tx_completed"] != uint64(numTransactions) {
t.Errorf("Expected %d transactions completed, got %v", numTransactions, stats["tx_completed"]) t.Errorf("Expected %d transactions completed, got %v", numTransactions, stats["tx_completed"])
} }
if stats["tx_active"] != uint64(0) { if stats["tx_active"] != uint64(0) {
t.Errorf("Expected 0 active transactions, got %v", stats["tx_active"]) t.Errorf("Expected 0 active transactions, got %v", stats["tx_active"])
} }
@ -176,45 +176,45 @@ func TestConcurrentTransactions(t *testing.T) {
func TestReadOnlyConcurrency(t *testing.T) { func TestReadOnlyConcurrency(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
// Create a transaction manager // Create a transaction manager
manager := NewManager(storage, statsCollector) manager := NewManager(storage, statsCollector)
// Initialize some data // Initialize some data
storage.Put([]byte("key1"), []byte("value1")) storage.Put([]byte("key1"), []byte("value1"))
// Create a WaitGroup to synchronize goroutines // Create a WaitGroup to synchronize goroutines
var wg sync.WaitGroup var wg sync.WaitGroup
// Number of concurrent read transactions to run // Number of concurrent read transactions to run
numReaders := 5 numReaders := 5
wg.Add(numReaders) wg.Add(numReaders)
// Channel to collect errors // Channel to collect errors
errors := make(chan error, numReaders) errors := make(chan error, numReaders)
// Start multiple read transactions concurrently // Start multiple read transactions concurrently
for i := 0; i < numReaders; i++ { for i := 0; i < numReaders; i++ {
go func() { go func() {
defer wg.Done() defer wg.Done()
// Start a read-only transaction // Start a read-only transaction
tx, err := manager.BeginTransaction(true) tx, err := manager.BeginTransaction(true)
if err != nil { if err != nil {
errors <- err errors <- err
return return
} }
// Read data // Read data
_, err = tx.Get([]byte("key1")) _, err = tx.Get([]byte("key1"))
if err != nil { if err != nil {
errors <- err errors <- err
return return
} }
// Simulate some processing time // Simulate some processing time
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
// Commit transaction // Commit transaction
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
@ -223,28 +223,28 @@ func TestReadOnlyConcurrency(t *testing.T) {
} }
}() }()
} }
// Wait for all readers to finish // Wait for all readers to finish
wg.Wait() wg.Wait()
close(errors) close(errors)
// Check for errors // Check for errors
for err := range errors { for err := range errors {
t.Errorf("Error in concurrent read transaction: %v", err) t.Errorf("Error in concurrent read transaction: %v", err)
} }
// Verify that all transactions completed // Verify that all transactions completed
stats := manager.GetTransactionStats() stats := manager.GetTransactionStats()
if stats["tx_started"] != uint64(numReaders) { if stats["tx_started"] != uint64(numReaders) {
t.Errorf("Expected %d transactions started, got %v", numReaders, stats["tx_started"]) t.Errorf("Expected %d transactions started, got %v", numReaders, stats["tx_started"])
} }
if stats["tx_completed"] != uint64(numReaders) { if stats["tx_completed"] != uint64(numReaders) {
t.Errorf("Expected %d transactions completed, got %v", numReaders, stats["tx_completed"]) t.Errorf("Expected %d transactions completed, got %v", numReaders, stats["tx_completed"])
} }
if stats["tx_active"] != uint64(0) { if stats["tx_active"] != uint64(0) {
t.Errorf("Expected 0 active transactions, got %v", stats["tx_active"]) t.Errorf("Expected 0 active transactions, got %v", stats["tx_active"])
} }
} }

View File

@ -85,7 +85,7 @@ func (s *MemoryStorage) newIterator(startKey, endKey []byte) *MemoryIterator {
keys := make([][]byte, 0, len(s.data)) keys := make([][]byte, 0, len(s.data))
for k := range s.data { for k := range s.data {
keyBytes := []byte(k) keyBytes := []byte(k)
// Apply range filtering if specified // Apply range filtering if specified
if startKey != nil && bytes.Compare(keyBytes, startKey) < 0 { if startKey != nil && bytes.Compare(keyBytes, startKey) < 0 {
continue continue
@ -93,7 +93,7 @@ func (s *MemoryStorage) newIterator(startKey, endKey []byte) *MemoryIterator {
if endKey != nil && bytes.Compare(keyBytes, endKey) >= 0 { if endKey != nil && bytes.Compare(keyBytes, endKey) >= 0 {
continue continue
} }
keys = append(keys, keyBytes) keys = append(keys, keyBytes)
} }
@ -162,7 +162,7 @@ func (it *MemoryIterator) Next() bool {
it.SeekToFirst() it.SeekToFirst()
return it.Valid() return it.Valid()
} }
if it.position >= len(it.keys)-1 { if it.position >= len(it.keys)-1 {
it.position = -1 it.position = -1
return false return false
@ -202,14 +202,14 @@ func (it *MemoryIterator) IsTombstone() bool {
func (s *MemoryStorage) Put(key, value []byte) { func (s *MemoryStorage) Put(key, value []byte) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Make a copy of the key and value // Make a copy of the key and value
keyCopy := make([]byte, len(key)) keyCopy := make([]byte, len(key))
copy(keyCopy, key) copy(keyCopy, key)
valueCopy := make([]byte, len(value)) valueCopy := make([]byte, len(value))
copy(valueCopy, value) copy(valueCopy, value)
s.data[string(keyCopy)] = valueCopy s.data[string(keyCopy)] = valueCopy
} }
@ -217,7 +217,7 @@ func (s *MemoryStorage) Put(key, value []byte) {
func (s *MemoryStorage) Delete(key []byte) { func (s *MemoryStorage) Delete(key []byte) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
delete(s.data, string(key)) delete(s.data, string(key))
} }
@ -225,6 +225,6 @@ func (s *MemoryStorage) Delete(key []byte) {
func (s *MemoryStorage) Size() int { func (s *MemoryStorage) Size() int {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return len(s.data) return len(s.data)
} }

View File

@ -79,4 +79,4 @@ func (s *StatsCollectorMock) IncrementTxCompleted() {
// IncrementTxAborted increments the aborted transaction counter // IncrementTxAborted increments the aborted transaction counter
func (s *StatsCollectorMock) IncrementTxAborted() { func (s *StatsCollectorMock) IncrementTxAborted() {
s.txAborted.Add(1) s.txAborted.Add(1)
} }

View File

@ -305,4 +305,4 @@ func (r *RegistryImpl) GracefulShutdown(ctx context.Context) error {
r.connectionTxs = make(map[string]map[string]struct{}) r.connectionTxs = make(map[string]map[string]struct{})
return lastErr return lastErr
} }

View File

@ -9,46 +9,46 @@ import (
func TestRegistryBasicOperations(t *testing.T) { func TestRegistryBasicOperations(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
// Create a transaction manager // Create a transaction manager
manager := NewManager(storage, statsCollector) manager := NewManager(storage, statsCollector)
// Create a registry // Create a registry
registry := NewRegistry() registry := NewRegistry()
// Test creating a new transaction // Test creating a new transaction
txID, err := registry.Begin(context.Background(), manager, true) txID, err := registry.Begin(context.Background(), manager, true)
if err != nil { if err != nil {
t.Errorf("Unexpected error beginning transaction: %v", err) t.Errorf("Unexpected error beginning transaction: %v", err)
} }
if txID == "" { if txID == "" {
t.Error("Expected non-empty transaction ID") t.Error("Expected non-empty transaction ID")
} }
// Test getting a transaction // Test getting a transaction
tx, exists := registry.Get(txID) tx, exists := registry.Get(txID)
if !exists { if !exists {
t.Errorf("Expected to find transaction %s", txID) t.Errorf("Expected to find transaction %s", txID)
} }
if tx == nil { if tx == nil {
t.Error("Expected non-nil transaction") t.Error("Expected non-nil transaction")
} }
if !tx.IsReadOnly() { if !tx.IsReadOnly() {
t.Error("Expected read-only transaction") t.Error("Expected read-only transaction")
} }
// Test operations on the transaction // Test operations on the transaction
_, err = tx.Get([]byte("test_key")) _, err = tx.Get([]byte("test_key"))
if err != nil && err != ErrKeyNotFound { if err != nil && err != ErrKeyNotFound {
t.Errorf("Unexpected error in transaction operation: %v", err) t.Errorf("Unexpected error in transaction operation: %v", err)
} }
// Remove the transaction from the registry // Remove the transaction from the registry
registry.Remove(txID) registry.Remove(txID)
// Transaction should no longer be in the registry // Transaction should no longer be in the registry
_, exists = registry.Get(txID) _, exists = registry.Get(txID)
if exists { if exists {
@ -59,49 +59,49 @@ func TestRegistryBasicOperations(t *testing.T) {
func TestRegistryConnectionCleanup(t *testing.T) { func TestRegistryConnectionCleanup(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
// Create a transaction manager // Create a transaction manager
manager := NewManager(storage, statsCollector) manager := NewManager(storage, statsCollector)
// Create a registry // Create a registry
registry := NewRegistry() registry := NewRegistry()
// Create context with connection ID // Create context with connection ID
ctx := context.WithValue(context.Background(), "peer", "connection1") ctx := context.WithValue(context.Background(), "peer", "connection1")
// Begin a read-only transaction first to avoid deadlock // Begin a read-only transaction first to avoid deadlock
txID1, err := registry.Begin(ctx, manager, true) txID1, err := registry.Begin(ctx, manager, true)
if err != nil { if err != nil {
t.Errorf("Unexpected error beginning transaction: %v", err) t.Errorf("Unexpected error beginning transaction: %v", err)
} }
// Get and commit the first transaction before starting the second // Get and commit the first transaction before starting the second
tx1, exists := registry.Get(txID1) tx1, exists := registry.Get(txID1)
if exists && tx1 != nil { if exists && tx1 != nil {
tx1.Commit() tx1.Commit()
} }
// Now begin a read-write transaction // Now begin a read-write transaction
txID2, err := registry.Begin(ctx, manager, false) txID2, err := registry.Begin(ctx, manager, false)
if err != nil { if err != nil {
t.Errorf("Unexpected error beginning transaction: %v", err) t.Errorf("Unexpected error beginning transaction: %v", err)
} }
// Verify transactions exist // Verify transactions exist
_, exists1 := registry.Get(txID1) _, exists1 := registry.Get(txID1)
_, exists2 := registry.Get(txID2) _, exists2 := registry.Get(txID2)
if !exists1 || !exists2 { if !exists1 || !exists2 {
t.Error("Expected both transactions to exist in registry") t.Error("Expected both transactions to exist in registry")
} }
// Clean up the connection // Clean up the connection
registry.CleanupConnection("connection1") registry.CleanupConnection("connection1")
// Verify transactions are removed // Verify transactions are removed
_, exists1 = registry.Get(txID1) _, exists1 = registry.Get(txID1)
_, exists2 = registry.Get(txID2) _, exists2 = registry.Get(txID2)
if exists1 || exists2 { if exists1 || exists2 {
t.Error("Expected all transactions to be removed after connection cleanup") t.Error("Expected all transactions to be removed after connection cleanup")
} }
@ -110,25 +110,25 @@ func TestRegistryConnectionCleanup(t *testing.T) {
func TestRegistryGracefulShutdown(t *testing.T) { func TestRegistryGracefulShutdown(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
// Create a transaction manager // Create a transaction manager
manager := NewManager(storage, statsCollector) manager := NewManager(storage, statsCollector)
// Create a registry // Create a registry
registry := NewRegistry() registry := NewRegistry()
// Begin a read-write transaction // Begin a read-write transaction
txID, err := registry.Begin(context.Background(), manager, false) txID, err := registry.Begin(context.Background(), manager, false)
if err != nil { if err != nil {
t.Errorf("Unexpected error beginning transaction: %v", err) t.Errorf("Unexpected error beginning transaction: %v", err)
} }
// Verify transaction exists // Verify transaction exists
_, exists := registry.Get(txID) _, exists := registry.Get(txID)
if !exists { if !exists {
t.Error("Expected transaction to exist in registry") t.Error("Expected transaction to exist in registry")
} }
// Perform graceful shutdown // Perform graceful shutdown
err = registry.GracefulShutdown(context.Background()) err = registry.GracefulShutdown(context.Background())
if err != nil { if err != nil {
@ -136,7 +136,7 @@ func TestRegistryGracefulShutdown(t *testing.T) {
// We'll just log it rather than failing the test // We'll just log it rather than failing the test
t.Logf("Note: Error during graceful shutdown (expected): %v", err) t.Logf("Note: Error during graceful shutdown (expected): %v", err)
} }
// Verify transaction is removed regardless of error // Verify transaction is removed regardless of error
_, exists = registry.Get(txID) _, exists = registry.Get(txID)
if exists { if exists {
@ -147,49 +147,49 @@ func TestRegistryGracefulShutdown(t *testing.T) {
func TestRegistryConcurrentOperations(t *testing.T) { func TestRegistryConcurrentOperations(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
// Create a transaction manager // Create a transaction manager
manager := NewManager(storage, statsCollector) manager := NewManager(storage, statsCollector)
// Create a registry // Create a registry
registry := NewRegistry() registry := NewRegistry()
// Instead of concurrent operations which can cause deadlocks in tests, // Instead of concurrent operations which can cause deadlocks in tests,
// we'll perform operations sequentially // we'll perform operations sequentially
numTransactions := 5 numTransactions := 5
// Track transaction IDs // Track transaction IDs
var txIDs []string var txIDs []string
// Create multiple transactions sequentially // Create multiple transactions sequentially
for i := 0; i < numTransactions; i++ { for i := 0; i < numTransactions; i++ {
// Create a context with a unique connection ID // Create a context with a unique connection ID
connID := fmt.Sprintf("connection-%d", i) connID := fmt.Sprintf("connection-%d", i)
ctx := context.WithValue(context.Background(), "peer", connID) ctx := context.WithValue(context.Background(), "peer", connID)
// Begin a transaction // Begin a transaction
txID, err := registry.Begin(ctx, manager, true) // Use read-only transactions to avoid locks txID, err := registry.Begin(ctx, manager, true) // Use read-only transactions to avoid locks
if err != nil { if err != nil {
t.Errorf("Failed to begin transaction %d: %v", i, err) t.Errorf("Failed to begin transaction %d: %v", i, err)
continue continue
} }
txIDs = append(txIDs, txID) txIDs = append(txIDs, txID)
// Get the transaction // Get the transaction
tx, exists := registry.Get(txID) tx, exists := registry.Get(txID)
if !exists { if !exists {
t.Errorf("Transaction %s not found", txID) t.Errorf("Transaction %s not found", txID)
continue continue
} }
// Test read operation // Test read operation
_, err = tx.Get([]byte("test_key")) _, err = tx.Get([]byte("test_key"))
if err != nil && err != ErrKeyNotFound { if err != nil && err != ErrKeyNotFound {
t.Errorf("Unexpected error in transaction operation: %v", err) t.Errorf("Unexpected error in transaction operation: %v", err)
} }
} }
// Clean up transactions // Clean up transactions
for _, txID := range txIDs { for _, txID := range txIDs {
tx, exists := registry.Get(txID) tx, exists := registry.Get(txID)
@ -201,7 +201,7 @@ func TestRegistryConcurrentOperations(t *testing.T) {
registry.Remove(txID) registry.Remove(txID)
} }
} }
// Verify all transactions are removed // Verify all transactions are removed
for _, txID := range txIDs { for _, txID := range txIDs {
_, exists := registry.Get(txID) _, exists := registry.Get(txID)
@ -209,4 +209,4 @@ func TestRegistryConcurrentOperations(t *testing.T) {
t.Errorf("Expected transaction %s to be removed", txID) t.Errorf("Expected transaction %s to be removed", txID)
} }
} }
} }

View File

@ -10,13 +10,13 @@ import (
type StorageBackend interface { type StorageBackend interface {
// Get retrieves a value for the given key // Get retrieves a value for the given key
Get(key []byte) ([]byte, error) Get(key []byte) ([]byte, error)
// ApplyBatch applies a batch of operations atomically // ApplyBatch applies a batch of operations atomically
ApplyBatch(entries []*wal.Entry) error ApplyBatch(entries []*wal.Entry) error
// GetIterator returns an iterator over all keys // GetIterator returns an iterator over all keys
GetIterator() (iterator.Iterator, error) GetIterator() (iterator.Iterator, error)
// GetRangeIterator returns an iterator limited to a specific key range // GetRangeIterator returns an iterator limited to a specific key range
GetRangeIterator(startKey, endKey []byte) (iterator.Iterator, error) GetRangeIterator(startKey, endKey []byte) (iterator.Iterator, error)
} }

View File

@ -14,28 +14,28 @@ import (
type TransactionImpl struct { type TransactionImpl struct {
// Reference to the storage backend // Reference to the storage backend
storage StorageBackend storage StorageBackend
// Transaction mode (ReadOnly or ReadWrite) // Transaction mode (ReadOnly or ReadWrite)
mode TransactionMode mode TransactionMode
// Buffer for transaction operations // Buffer for transaction operations
buffer *Buffer buffer *Buffer
// Tracks if the transaction is still active // Tracks if the transaction is still active
active atomic.Bool active atomic.Bool
// For read-only transactions, tracks if we have a read lock // For read-only transactions, tracks if we have a read lock
hasReadLock atomic.Bool hasReadLock atomic.Bool
// For read-write transactions, tracks if we have the write lock // For read-write transactions, tracks if we have the write lock
hasWriteLock atomic.Bool hasWriteLock atomic.Bool
// Lock for transaction-level synchronization // Lock for transaction-level synchronization
mu sync.Mutex mu sync.Mutex
// RWLock for transaction isolation // RWLock for transaction isolation
rwLock *sync.RWMutex rwLock *sync.RWMutex
// Stats collector // Stats collector
stats StatsCollector stats StatsCollector
} }
@ -51,12 +51,12 @@ func (tx *TransactionImpl) Get(key []byte) ([]byte, error) {
// Use transaction lock for consistent view // Use transaction lock for consistent view
tx.mu.Lock() tx.mu.Lock()
defer tx.mu.Unlock() defer tx.mu.Unlock()
// Check if transaction is still active // Check if transaction is still active
if !tx.active.Load() { if !tx.active.Load() {
return nil, ErrTransactionClosed return nil, ErrTransactionClosed
} }
// First check the transaction buffer for any pending changes // First check the transaction buffer for any pending changes
if val, found := tx.buffer.Get(key); found { if val, found := tx.buffer.Get(key); found {
if val == nil { if val == nil {
@ -65,7 +65,7 @@ func (tx *TransactionImpl) Get(key []byte) ([]byte, error) {
} }
return val, nil return val, nil
} }
// Not in the buffer, get from the underlying storage // Not in the buffer, get from the underlying storage
return tx.storage.Get(key) return tx.storage.Get(key)
} }
@ -75,17 +75,17 @@ func (tx *TransactionImpl) Put(key, value []byte) error {
// Use transaction lock for consistent view // Use transaction lock for consistent view
tx.mu.Lock() tx.mu.Lock()
defer tx.mu.Unlock() defer tx.mu.Unlock()
// Check if transaction is still active // Check if transaction is still active
if !tx.active.Load() { if !tx.active.Load() {
return ErrTransactionClosed return ErrTransactionClosed
} }
// Check if transaction is read-only // Check if transaction is read-only
if tx.mode == ReadOnly { if tx.mode == ReadOnly {
return ErrReadOnlyTransaction return ErrReadOnlyTransaction
} }
// Buffer the change - it will be applied on commit // Buffer the change - it will be applied on commit
tx.buffer.Put(key, value) tx.buffer.Put(key, value)
return nil return nil
@ -96,17 +96,17 @@ func (tx *TransactionImpl) Delete(key []byte) error {
// Use transaction lock for consistent view // Use transaction lock for consistent view
tx.mu.Lock() tx.mu.Lock()
defer tx.mu.Unlock() defer tx.mu.Unlock()
// Check if transaction is still active // Check if transaction is still active
if !tx.active.Load() { if !tx.active.Load() {
return ErrTransactionClosed return ErrTransactionClosed
} }
// Check if transaction is read-only // Check if transaction is read-only
if tx.mode == ReadOnly { if tx.mode == ReadOnly {
return ErrReadOnlyTransaction return ErrReadOnlyTransaction
} }
// Buffer the deletion - it will be applied on commit // Buffer the deletion - it will be applied on commit
tx.buffer.Delete(key) tx.buffer.Delete(key)
return nil return nil
@ -117,28 +117,28 @@ func (tx *TransactionImpl) NewIterator() iterator.Iterator {
// Use transaction lock for consistent view // Use transaction lock for consistent view
tx.mu.Lock() tx.mu.Lock()
defer tx.mu.Unlock() defer tx.mu.Unlock()
// Check if transaction is still active // Check if transaction is still active
if !tx.active.Load() { if !tx.active.Load() {
// Return an empty iterator // Return an empty iterator
return &emptyIterator{} return &emptyIterator{}
} }
// Get the storage iterator // Get the storage iterator
storageIter, err := tx.storage.GetIterator() storageIter, err := tx.storage.GetIterator()
if err != nil { if err != nil {
// If we can't get a storage iterator, return a buffer-only iterator // If we can't get a storage iterator, return a buffer-only iterator
return tx.buffer.NewIterator() return tx.buffer.NewIterator()
} }
// If there are no changes in the buffer, just use the storage's iterator // If there are no changes in the buffer, just use the storage's iterator
if tx.buffer.Size() == 0 { if tx.buffer.Size() == 0 {
return storageIter return storageIter
} }
// Merge buffer and storage iterators // Merge buffer and storage iterators
bufferIter := tx.buffer.NewIterator() bufferIter := tx.buffer.NewIterator()
// Use composite hierarchical iterator // Use composite hierarchical iterator
return composite.NewHierarchicalIterator([]iterator.Iterator{bufferIter, storageIter}) return composite.NewHierarchicalIterator([]iterator.Iterator{bufferIter, storageIter})
} }
@ -148,13 +148,13 @@ func (tx *TransactionImpl) NewRangeIterator(startKey, endKey []byte) iterator.It
// Use transaction lock for consistent view // Use transaction lock for consistent view
tx.mu.Lock() tx.mu.Lock()
defer tx.mu.Unlock() defer tx.mu.Unlock()
// Check if transaction is still active // Check if transaction is still active
if !tx.active.Load() { if !tx.active.Load() {
// Return an empty iterator // Return an empty iterator
return &emptyIterator{} return &emptyIterator{}
} }
// Get the storage iterator for the range // Get the storage iterator for the range
storageIter, err := tx.storage.GetRangeIterator(startKey, endKey) storageIter, err := tx.storage.GetRangeIterator(startKey, endKey)
if err != nil { if err != nil {
@ -162,16 +162,16 @@ func (tx *TransactionImpl) NewRangeIterator(startKey, endKey []byte) iterator.It
bufferIter := tx.buffer.NewIterator() bufferIter := tx.buffer.NewIterator()
return bounded.NewBoundedIterator(bufferIter, startKey, endKey) return bounded.NewBoundedIterator(bufferIter, startKey, endKey)
} }
// If there are no changes in the buffer, just use the storage's range iterator // If there are no changes in the buffer, just use the storage's range iterator
if tx.buffer.Size() == 0 { if tx.buffer.Size() == 0 {
return storageIter return storageIter
} }
// Create a bounded buffer iterator // Create a bounded buffer iterator
bufferIter := tx.buffer.NewIterator() bufferIter := tx.buffer.NewIterator()
boundedBufferIter := bounded.NewBoundedIterator(bufferIter, startKey, endKey) boundedBufferIter := bounded.NewBoundedIterator(bufferIter, startKey, endKey)
// Merge the bounded buffer iterator with the storage range iterator // Merge the bounded buffer iterator with the storage range iterator
return composite.NewHierarchicalIterator([]iterator.Iterator{boundedBufferIter, storageIter}) return composite.NewHierarchicalIterator([]iterator.Iterator{boundedBufferIter, storageIter})
} }
@ -193,34 +193,34 @@ func (tx *TransactionImpl) Commit() error {
// Use transaction lock for consistent view // Use transaction lock for consistent view
tx.mu.Lock() tx.mu.Lock()
defer tx.mu.Unlock() defer tx.mu.Unlock()
// Only proceed if the transaction is still active // Only proceed if the transaction is still active
if !tx.active.CompareAndSwap(true, false) { if !tx.active.CompareAndSwap(true, false) {
return ErrTransactionClosed return ErrTransactionClosed
} }
var err error var err error
// For read-only transactions, just release the read lock // For read-only transactions, just release the read lock
if tx.mode == ReadOnly { if tx.mode == ReadOnly {
tx.releaseReadLock() tx.releaseReadLock()
// Track transaction completion // Track transaction completion
if tx.stats != nil { if tx.stats != nil {
tx.stats.IncrementTxCompleted() tx.stats.IncrementTxCompleted()
} }
return nil return nil
} }
// For read-write transactions, apply the changes // For read-write transactions, apply the changes
if tx.buffer.Size() > 0 { if tx.buffer.Size() > 0 {
// Get operations from the buffer // Get operations from the buffer
ops := tx.buffer.Operations() ops := tx.buffer.Operations()
// Create a batch for all operations // Create a batch for all operations
walBatch := make([]*wal.Entry, 0, len(ops)) walBatch := make([]*wal.Entry, 0, len(ops))
// Build WAL entries for each operation // Build WAL entries for each operation
for _, op := range ops { for _, op := range ops {
if op.IsDelete { if op.IsDelete {
@ -238,19 +238,19 @@ func (tx *TransactionImpl) Commit() error {
}) })
} }
} }
// Apply the batch atomically // Apply the batch atomically
err = tx.storage.ApplyBatch(walBatch) err = tx.storage.ApplyBatch(walBatch)
} }
// Release the write lock // Release the write lock
tx.releaseWriteLock() tx.releaseWriteLock()
// Track transaction completion // Track transaction completion
if tx.stats != nil { if tx.stats != nil {
tx.stats.IncrementTxCompleted() tx.stats.IncrementTxCompleted()
} }
return err return err
} }
@ -259,27 +259,27 @@ func (tx *TransactionImpl) Rollback() error {
// Use transaction lock for consistent view // Use transaction lock for consistent view
tx.mu.Lock() tx.mu.Lock()
defer tx.mu.Unlock() defer tx.mu.Unlock()
// Only proceed if the transaction is still active // Only proceed if the transaction is still active
if !tx.active.CompareAndSwap(true, false) { if !tx.active.CompareAndSwap(true, false) {
return ErrTransactionClosed return ErrTransactionClosed
} }
// Clear the buffer // Clear the buffer
tx.buffer.Clear() tx.buffer.Clear()
// Release locks based on transaction mode // Release locks based on transaction mode
if tx.mode == ReadOnly { if tx.mode == ReadOnly {
tx.releaseReadLock() tx.releaseReadLock()
} else { } else {
tx.releaseWriteLock() tx.releaseWriteLock()
} }
// Track transaction abort // Track transaction abort
if tx.stats != nil { if tx.stats != nil {
tx.stats.IncrementTxAborted() tx.stats.IncrementTxAborted()
} }
return nil return nil
} }
@ -300,4 +300,4 @@ func (tx *TransactionImpl) releaseWriteLock() {
if tx.hasWriteLock.CompareAndSwap(true, false) { if tx.hasWriteLock.CompareAndSwap(true, false) {
tx.rwLock.Unlock() tx.rwLock.Unlock()
} }
} }

View File

@ -10,25 +10,25 @@ func TestTransactionBasicOperations(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
rwLock := &sync.RWMutex{} rwLock := &sync.RWMutex{}
// Prepare some initial data // Prepare some initial data
storage.Put([]byte("existing1"), []byte("value1")) storage.Put([]byte("existing1"), []byte("value1"))
storage.Put([]byte("existing2"), []byte("value2")) storage.Put([]byte("existing2"), []byte("value2"))
// Create a transaction // Create a transaction
tx := &TransactionImpl{ tx := &TransactionImpl{
storage: storage, storage: storage,
mode: ReadWrite, mode: ReadWrite,
buffer: NewBuffer(), buffer: NewBuffer(),
rwLock: rwLock, rwLock: rwLock,
stats: statsCollector, stats: statsCollector,
} }
tx.active.Store(true) tx.active.Store(true)
// Actually acquire the write lock before setting the flag // Actually acquire the write lock before setting the flag
rwLock.Lock() rwLock.Lock()
tx.hasWriteLock.Store(true) tx.hasWriteLock.Store(true)
// Test Get existing key // Test Get existing key
value, err := tx.Get([]byte("existing1")) value, err := tx.Get([]byte("existing1"))
if err != nil { if err != nil {
@ -37,19 +37,19 @@ func TestTransactionBasicOperations(t *testing.T) {
if !bytes.Equal(value, []byte("value1")) { if !bytes.Equal(value, []byte("value1")) {
t.Errorf("Expected value 'value1', got %s", value) t.Errorf("Expected value 'value1', got %s", value)
} }
// Test Get non-existing key // Test Get non-existing key
_, err = tx.Get([]byte("nonexistent")) _, err = tx.Get([]byte("nonexistent"))
if err == nil || err != ErrKeyNotFound { if err == nil || err != ErrKeyNotFound {
t.Errorf("Expected ErrKeyNotFound for nonexistent key, got %v", err) t.Errorf("Expected ErrKeyNotFound for nonexistent key, got %v", err)
} }
// Test Put and then Get from buffer // Test Put and then Get from buffer
err = tx.Put([]byte("key1"), []byte("new_value1")) err = tx.Put([]byte("key1"), []byte("new_value1"))
if err != nil { if err != nil {
t.Errorf("Unexpected error putting key: %v", err) t.Errorf("Unexpected error putting key: %v", err)
} }
value, err = tx.Get([]byte("key1")) value, err = tx.Get([]byte("key1"))
if err != nil { if err != nil {
t.Errorf("Unexpected error getting key from buffer: %v", err) t.Errorf("Unexpected error getting key from buffer: %v", err)
@ -57,13 +57,13 @@ func TestTransactionBasicOperations(t *testing.T) {
if !bytes.Equal(value, []byte("new_value1")) { if !bytes.Equal(value, []byte("new_value1")) {
t.Errorf("Expected buffer value 'new_value1', got %s", value) t.Errorf("Expected buffer value 'new_value1', got %s", value)
} }
// Test overwriting existing key // Test overwriting existing key
err = tx.Put([]byte("existing1"), []byte("updated_value1")) err = tx.Put([]byte("existing1"), []byte("updated_value1"))
if err != nil { if err != nil {
t.Errorf("Unexpected error updating key: %v", err) t.Errorf("Unexpected error updating key: %v", err)
} }
value, err = tx.Get([]byte("existing1")) value, err = tx.Get([]byte("existing1"))
if err != nil { if err != nil {
t.Errorf("Unexpected error getting updated key: %v", err) t.Errorf("Unexpected error getting updated key: %v", err)
@ -71,50 +71,50 @@ func TestTransactionBasicOperations(t *testing.T) {
if !bytes.Equal(value, []byte("updated_value1")) { if !bytes.Equal(value, []byte("updated_value1")) {
t.Errorf("Expected updated value 'updated_value1', got %s", value) t.Errorf("Expected updated value 'updated_value1', got %s", value)
} }
// Test Delete operation // Test Delete operation
err = tx.Delete([]byte("existing2")) err = tx.Delete([]byte("existing2"))
if err != nil { if err != nil {
t.Errorf("Unexpected error deleting key: %v", err) t.Errorf("Unexpected error deleting key: %v", err)
} }
_, err = tx.Get([]byte("existing2")) _, err = tx.Get([]byte("existing2"))
if err == nil || err != ErrKeyNotFound { if err == nil || err != ErrKeyNotFound {
t.Errorf("Expected ErrKeyNotFound for deleted key, got %v", err) t.Errorf("Expected ErrKeyNotFound for deleted key, got %v", err)
} }
// Test operations on closed transaction // Test operations on closed transaction
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
t.Errorf("Unexpected error committing transaction: %v", err) t.Errorf("Unexpected error committing transaction: %v", err)
} }
// After commit, the transaction should be closed // After commit, the transaction should be closed
_, err = tx.Get([]byte("key1")) _, err = tx.Get([]byte("key1"))
if err == nil || err != ErrTransactionClosed { if err == nil || err != ErrTransactionClosed {
t.Errorf("Expected ErrTransactionClosed, got %v", err) t.Errorf("Expected ErrTransactionClosed, got %v", err)
} }
err = tx.Put([]byte("key2"), []byte("value2")) err = tx.Put([]byte("key2"), []byte("value2"))
if err == nil || err != ErrTransactionClosed { if err == nil || err != ErrTransactionClosed {
t.Errorf("Expected ErrTransactionClosed, got %v", err) t.Errorf("Expected ErrTransactionClosed, got %v", err)
} }
err = tx.Delete([]byte("key1")) err = tx.Delete([]byte("key1"))
if err == nil || err != ErrTransactionClosed { if err == nil || err != ErrTransactionClosed {
t.Errorf("Expected ErrTransactionClosed, got %v", err) t.Errorf("Expected ErrTransactionClosed, got %v", err)
} }
err = tx.Commit() err = tx.Commit()
if err == nil || err != ErrTransactionClosed { if err == nil || err != ErrTransactionClosed {
t.Errorf("Expected ErrTransactionClosed for second commit, got %v", err) t.Errorf("Expected ErrTransactionClosed for second commit, got %v", err)
} }
err = tx.Rollback() err = tx.Rollback()
if err == nil || err != ErrTransactionClosed { if err == nil || err != ErrTransactionClosed {
t.Errorf("Expected ErrTransactionClosed for rollback after commit, got %v", err) t.Errorf("Expected ErrTransactionClosed for rollback after commit, got %v", err)
} }
// Verify committed changes exist in storage // Verify committed changes exist in storage
val, err := storage.Get([]byte("key1")) val, err := storage.Get([]byte("key1"))
if err != nil { if err != nil {
@ -123,7 +123,7 @@ func TestTransactionBasicOperations(t *testing.T) {
if !bytes.Equal(val, []byte("new_value1")) { if !bytes.Equal(val, []byte("new_value1")) {
t.Errorf("Expected value 'new_value1' in storage, got %s", val) t.Errorf("Expected value 'new_value1' in storage, got %s", val)
} }
val, err = storage.Get([]byte("existing1")) val, err = storage.Get([]byte("existing1"))
if err != nil { if err != nil {
t.Errorf("Expected existing1 to exist in storage with updated value, got error: %v", err) t.Errorf("Expected existing1 to exist in storage with updated value, got error: %v", err)
@ -131,7 +131,7 @@ func TestTransactionBasicOperations(t *testing.T) {
if !bytes.Equal(val, []byte("updated_value1")) { if !bytes.Equal(val, []byte("updated_value1")) {
t.Errorf("Expected value 'updated_value1' in storage, got %s", val) t.Errorf("Expected value 'updated_value1' in storage, got %s", val)
} }
_, err = storage.Get([]byte("existing2")) _, err = storage.Get([]byte("existing2"))
if err == nil || err != ErrKeyNotFound { if err == nil || err != ErrKeyNotFound {
t.Errorf("Expected existing2 to be deleted from storage, got: %v", err) t.Errorf("Expected existing2 to be deleted from storage, got: %v", err)
@ -142,24 +142,24 @@ func TestReadOnlyTransactionOperations(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
rwLock := &sync.RWMutex{} rwLock := &sync.RWMutex{}
// Prepare some initial data // Prepare some initial data
storage.Put([]byte("key1"), []byte("value1")) storage.Put([]byte("key1"), []byte("value1"))
// Create a read-only transaction // Create a read-only transaction
tx := &TransactionImpl{ tx := &TransactionImpl{
storage: storage, storage: storage,
mode: ReadOnly, mode: ReadOnly,
buffer: NewBuffer(), buffer: NewBuffer(),
rwLock: rwLock, rwLock: rwLock,
stats: statsCollector, stats: statsCollector,
} }
tx.active.Store(true) tx.active.Store(true)
// Actually acquire the read lock before setting the flag // Actually acquire the read lock before setting the flag
rwLock.RLock() rwLock.RLock()
tx.hasReadLock.Store(true) tx.hasReadLock.Store(true)
// Test Get // Test Get
value, err := tx.Get([]byte("key1")) value, err := tx.Get([]byte("key1"))
if err != nil { if err != nil {
@ -168,30 +168,30 @@ func TestReadOnlyTransactionOperations(t *testing.T) {
if !bytes.Equal(value, []byte("value1")) { if !bytes.Equal(value, []byte("value1")) {
t.Errorf("Expected value 'value1', got %s", value) t.Errorf("Expected value 'value1', got %s", value)
} }
// Test Put on read-only transaction (should fail) // Test Put on read-only transaction (should fail)
err = tx.Put([]byte("key2"), []byte("value2")) err = tx.Put([]byte("key2"), []byte("value2"))
if err == nil || err != ErrReadOnlyTransaction { if err == nil || err != ErrReadOnlyTransaction {
t.Errorf("Expected ErrReadOnlyTransaction, got %v", err) t.Errorf("Expected ErrReadOnlyTransaction, got %v", err)
} }
// Test Delete on read-only transaction (should fail) // Test Delete on read-only transaction (should fail)
err = tx.Delete([]byte("key1")) err = tx.Delete([]byte("key1"))
if err == nil || err != ErrReadOnlyTransaction { if err == nil || err != ErrReadOnlyTransaction {
t.Errorf("Expected ErrReadOnlyTransaction, got %v", err) t.Errorf("Expected ErrReadOnlyTransaction, got %v", err)
} }
// Test IsReadOnly // Test IsReadOnly
if !tx.IsReadOnly() { if !tx.IsReadOnly() {
t.Error("Expected IsReadOnly() to return true") t.Error("Expected IsReadOnly() to return true")
} }
// Test Commit on read-only transaction // Test Commit on read-only transaction
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
t.Errorf("Unexpected error committing read-only tx: %v", err) t.Errorf("Unexpected error committing read-only tx: %v", err)
} }
// After commit, the transaction should be closed // After commit, the transaction should be closed
_, err = tx.Get([]byte("key1")) _, err = tx.Get([]byte("key1"))
if err == nil || err != ErrTransactionClosed { if err == nil || err != ErrTransactionClosed {
@ -203,47 +203,47 @@ func TestTransactionRollback(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
rwLock := &sync.RWMutex{} rwLock := &sync.RWMutex{}
// Prepare some initial data // Prepare some initial data
storage.Put([]byte("key1"), []byte("value1")) storage.Put([]byte("key1"), []byte("value1"))
// Create a transaction // Create a transaction
tx := &TransactionImpl{ tx := &TransactionImpl{
storage: storage, storage: storage,
mode: ReadWrite, mode: ReadWrite,
buffer: NewBuffer(), buffer: NewBuffer(),
rwLock: rwLock, rwLock: rwLock,
stats: statsCollector, stats: statsCollector,
} }
tx.active.Store(true) tx.active.Store(true)
// Actually acquire the write lock before setting the flag // Actually acquire the write lock before setting the flag
rwLock.Lock() rwLock.Lock()
tx.hasWriteLock.Store(true) tx.hasWriteLock.Store(true)
// Make some changes // Make some changes
err := tx.Put([]byte("key2"), []byte("value2")) err := tx.Put([]byte("key2"), []byte("value2"))
if err != nil { if err != nil {
t.Errorf("Unexpected error putting key: %v", err) t.Errorf("Unexpected error putting key: %v", err)
} }
err = tx.Delete([]byte("key1")) err = tx.Delete([]byte("key1"))
if err != nil { if err != nil {
t.Errorf("Unexpected error deleting key: %v", err) t.Errorf("Unexpected error deleting key: %v", err)
} }
// Rollback the transaction // Rollback the transaction
err = tx.Rollback() err = tx.Rollback()
if err != nil { if err != nil {
t.Errorf("Unexpected error rolling back tx: %v", err) t.Errorf("Unexpected error rolling back tx: %v", err)
} }
// After rollback, the transaction should be closed // After rollback, the transaction should be closed
_, err = tx.Get([]byte("key1")) _, err = tx.Get([]byte("key1"))
if err == nil || err != ErrTransactionClosed { if err == nil || err != ErrTransactionClosed {
t.Errorf("Expected ErrTransactionClosed, got %v", err) t.Errorf("Expected ErrTransactionClosed, got %v", err)
} }
// Verify changes were not applied to storage // Verify changes were not applied to storage
val, err := storage.Get([]byte("key1")) val, err := storage.Get([]byte("key1"))
if err != nil { if err != nil {
@ -252,7 +252,7 @@ func TestTransactionRollback(t *testing.T) {
if !bytes.Equal(val, []byte("value1")) { if !bytes.Equal(val, []byte("value1")) {
t.Errorf("Expected value 'value1' in storage, got %s", val) t.Errorf("Expected value 'value1' in storage, got %s", val)
} }
_, err = storage.Get([]byte("key2")) _, err = storage.Get([]byte("key2"))
if err == nil || err != ErrKeyNotFound { if err == nil || err != ErrKeyNotFound {
t.Errorf("Expected key2 to not exist in storage after rollback, got: %v", err) t.Errorf("Expected key2 to not exist in storage after rollback, got: %v", err)
@ -263,47 +263,47 @@ func TestTransactionIterators(t *testing.T) {
storage := NewMemoryStorage() storage := NewMemoryStorage()
statsCollector := &StatsCollectorMock{} statsCollector := &StatsCollectorMock{}
rwLock := &sync.RWMutex{} rwLock := &sync.RWMutex{}
// Prepare some initial data // Prepare some initial data
storage.Put([]byte("a"), []byte("value_a")) storage.Put([]byte("a"), []byte("value_a"))
storage.Put([]byte("c"), []byte("value_c")) storage.Put([]byte("c"), []byte("value_c"))
storage.Put([]byte("e"), []byte("value_e")) storage.Put([]byte("e"), []byte("value_e"))
// Create a transaction // Create a transaction
tx := &TransactionImpl{ tx := &TransactionImpl{
storage: storage, storage: storage,
mode: ReadWrite, mode: ReadWrite,
buffer: NewBuffer(), buffer: NewBuffer(),
rwLock: rwLock, rwLock: rwLock,
stats: statsCollector, stats: statsCollector,
} }
tx.active.Store(true) tx.active.Store(true)
// Actually acquire the write lock before setting the flag // Actually acquire the write lock before setting the flag
rwLock.Lock() rwLock.Lock()
tx.hasWriteLock.Store(true) tx.hasWriteLock.Store(true)
// Make some changes to the transaction buffer // Make some changes to the transaction buffer
tx.Put([]byte("b"), []byte("value_b")) tx.Put([]byte("b"), []byte("value_b"))
tx.Put([]byte("d"), []byte("value_d")) tx.Put([]byte("d"), []byte("value_d"))
tx.Delete([]byte("c")) // Delete an existing key tx.Delete([]byte("c")) // Delete an existing key
// Test full iterator // Test full iterator
it := tx.NewIterator() it := tx.NewIterator()
// Collect all keys and values // Collect all keys and values
var keys [][]byte var keys [][]byte
var values [][]byte var values [][]byte
for it.SeekToFirst(); it.Valid(); it.Next() { for it.SeekToFirst(); it.Valid(); it.Next() {
keys = append(keys, append([]byte{}, it.Key()...)) keys = append(keys, append([]byte{}, it.Key()...))
values = append(values, append([]byte{}, it.Value()...)) values = append(values, append([]byte{}, it.Value()...))
} }
// The iterator might still return the deleted key 'c' (with a tombstone marker) // The iterator might still return the deleted key 'c' (with a tombstone marker)
// Print the actual keys for debugging // Print the actual keys for debugging
t.Logf("Actual keys in iterator: %v", keys) t.Logf("Actual keys in iterator: %v", keys)
// Define expected keys (a, b, d, e) - c is deleted but might appear as a tombstone // Define expected keys (a, b, d, e) - c is deleted but might appear as a tombstone
expectedKeySet := map[string]bool{ expectedKeySet := map[string]bool{
"a": true, "a": true,
@ -311,7 +311,7 @@ func TestTransactionIterators(t *testing.T) {
"d": true, "d": true,
"e": true, "e": true,
} }
// Check each key is in our expected set // Check each key is in our expected set
for _, key := range keys { for _, key := range keys {
keyStr := string(key) keyStr := string(key)
@ -319,7 +319,7 @@ func TestTransactionIterators(t *testing.T) {
t.Errorf("Found unexpected key: %s", keyStr) t.Errorf("Found unexpected key: %s", keyStr)
} }
} }
// Verify we have at least our expected keys // Verify we have at least our expected keys
for k := range expectedKeySet { for k := range expectedKeySet {
found := false found := false
@ -333,29 +333,29 @@ func TestTransactionIterators(t *testing.T) {
t.Errorf("Expected key %s not found in iterator", k) t.Errorf("Expected key %s not found in iterator", k)
} }
} }
// Test range iterator // Test range iterator
rangeIt := tx.NewRangeIterator([]byte("b"), []byte("e")) rangeIt := tx.NewRangeIterator([]byte("b"), []byte("e"))
// Collect all keys and values in range // Collect all keys and values in range
keys = nil keys = nil
values = nil values = nil
for rangeIt.SeekToFirst(); rangeIt.Valid(); rangeIt.Next() { for rangeIt.SeekToFirst(); rangeIt.Valid(); rangeIt.Next() {
keys = append(keys, append([]byte{}, rangeIt.Key()...)) keys = append(keys, append([]byte{}, rangeIt.Key()...))
values = append(values, append([]byte{}, rangeIt.Value()...)) values = append(values, append([]byte{}, rangeIt.Value()...))
} }
// The range should include b and d, and might include c with a tombstone // The range should include b and d, and might include c with a tombstone
// Print the actual keys for debugging // Print the actual keys for debugging
t.Logf("Actual keys in range iterator: %v", keys) t.Logf("Actual keys in range iterator: %v", keys)
// Ensure the keys include our expected ones (b, d) // Ensure the keys include our expected ones (b, d)
expectedRangeSet := map[string]bool{ expectedRangeSet := map[string]bool{
"b": true, "b": true,
"d": true, "d": true,
} }
// Check each key is in our expected set (or is c which might appear as a tombstone) // Check each key is in our expected set (or is c which might appear as a tombstone)
for _, key := range keys { for _, key := range keys {
keyStr := string(key) keyStr := string(key)
@ -363,7 +363,7 @@ func TestTransactionIterators(t *testing.T) {
t.Errorf("Found unexpected key in range: %s", keyStr) t.Errorf("Found unexpected key in range: %s", keyStr)
} }
} }
// Verify we have at least our expected keys // Verify we have at least our expected keys
for k := range expectedRangeSet { for k := range expectedRangeSet {
found := false found := false
@ -377,17 +377,17 @@ func TestTransactionIterators(t *testing.T) {
t.Errorf("Expected key %s not found in range iterator", k) t.Errorf("Expected key %s not found in range iterator", k)
} }
} }
// Test iterator on closed transaction // Test iterator on closed transaction
tx.Commit() tx.Commit()
closedIt := tx.NewIterator() closedIt := tx.NewIterator()
if closedIt.Valid() { if closedIt.Valid() {
t.Error("Expected iterator on closed transaction to be invalid") t.Error("Expected iterator on closed transaction to be invalid")
} }
closedRangeIt := tx.NewRangeIterator([]byte("a"), []byte("z")) closedRangeIt := tx.NewRangeIterator([]byte("a"), []byte("z"))
if closedRangeIt.Valid() { if closedRangeIt.Valid() {
t.Error("Expected range iterator on closed transaction to be invalid") t.Error("Expected range iterator on closed transaction to be invalid")
} }
} }

View File

@ -112,7 +112,7 @@ func NewWAL(cfg *config.Config, dir string) (*WAL, error) {
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create WAL directory: %w", err) return nil, fmt.Errorf("failed to create WAL directory: %w", err)
} }
// Verify that the directory was successfully created // Verify that the directory was successfully created
if _, err := os.Stat(dir); os.IsNotExist(err) { if _, err := os.Stat(dir); os.IsNotExist(err) {
return nil, fmt.Errorf("WAL directory creation failed: %s does not exist after MkdirAll", dir) return nil, fmt.Errorf("WAL directory creation failed: %s does not exist after MkdirAll", dir)
@ -445,24 +445,24 @@ func (w *WAL) AppendExactBytes(rawBytes []byte, seqNum uint64) (uint64, error) {
} else if status == WALStatusRotating { } else if status == WALStatusRotating {
return 0, ErrWALRotating return 0, ErrWALRotating
} }
// Verify we have at least a header // Verify we have at least a header
if len(rawBytes) < HeaderSize { if len(rawBytes) < HeaderSize {
return 0, fmt.Errorf("raw WAL record too small: %d bytes", len(rawBytes)) return 0, fmt.Errorf("raw WAL record too small: %d bytes", len(rawBytes))
} }
// Extract payload size to validate record integrity // Extract payload size to validate record integrity
payloadSize := int(binary.LittleEndian.Uint16(rawBytes[4:6])) payloadSize := int(binary.LittleEndian.Uint16(rawBytes[4:6]))
if len(rawBytes) != HeaderSize + payloadSize { if len(rawBytes) != HeaderSize+payloadSize {
return 0, fmt.Errorf("raw WAL record size mismatch: header says %d payload bytes, but got %d total bytes", return 0, fmt.Errorf("raw WAL record size mismatch: header says %d payload bytes, but got %d total bytes",
payloadSize, len(rawBytes)) payloadSize, len(rawBytes))
} }
// Update nextSequence if the provided sequence is higher // Update nextSequence if the provided sequence is higher
if seqNum >= w.nextSequence { if seqNum >= w.nextSequence {
w.nextSequence = seqNum + 1 w.nextSequence = seqNum + 1
} }
// Write the raw bytes directly to the WAL // Write the raw bytes directly to the WAL
if _, err := w.writer.Write(rawBytes); err != nil { if _, err := w.writer.Write(rawBytes); err != nil {
return 0, fmt.Errorf("failed to write raw WAL record: %w", err) return 0, fmt.Errorf("failed to write raw WAL record: %w", err)
@ -471,7 +471,7 @@ func (w *WAL) AppendExactBytes(rawBytes []byte, seqNum uint64) (uint64, error) {
// Update bytes written // Update bytes written
w.bytesWritten += int64(len(rawBytes)) w.bytesWritten += int64(len(rawBytes))
w.batchByteSize += int64(len(rawBytes)) w.batchByteSize += int64(len(rawBytes))
// Notify observers (with a simplified Entry since we can't properly parse the raw bytes) // Notify observers (with a simplified Entry since we can't properly parse the raw bytes)
entry := &Entry{ entry := &Entry{
SequenceNumber: seqNum, SequenceNumber: seqNum,
@ -480,12 +480,12 @@ func (w *WAL) AppendExactBytes(rawBytes []byte, seqNum uint64) (uint64, error) {
Value: []byte{}, Value: []byte{},
} }
w.notifyEntryObservers(entry) w.notifyEntryObservers(entry)
// Sync if needed // Sync if needed
if err := w.maybeSync(); err != nil { if err := w.maybeSync(); err != nil {
return 0, err return 0, err
} }
return seqNum, nil return seqNum, nil
} }

View File

@ -596,9 +596,9 @@ func TestAppendWithSequence(t *testing.T) {
// Write entries with specific sequence numbers // Write entries with specific sequence numbers
testCases := []struct { testCases := []struct {
key string key string
value string value string
seqNum uint64 seqNum uint64
entryType uint8 entryType uint8
}{ }{
{"key1", "value1", 100, OpTypePut}, {"key1", "value1", 100, OpTypePut},
@ -787,27 +787,27 @@ func TestAppendBatchWithSequence(t *testing.T) {
if batch.Seq != startSeq { if batch.Seq != startSeq {
t.Errorf("Expected batch seq %d, got %d", startSeq, batch.Seq) t.Errorf("Expected batch seq %d, got %d", startSeq, batch.Seq)
} }
// Verify batch count // Verify batch count
if len(batch.Operations) != len(entries) { if len(batch.Operations) != len(entries) {
t.Errorf("Expected %d operations, got %d", len(entries), len(batch.Operations)) t.Errorf("Expected %d operations, got %d", len(entries), len(batch.Operations))
} }
// Verify batch operations // Verify batch operations
for i, op := range batch.Operations { for i, op := range batch.Operations {
if i < len(entries) { if i < len(entries) {
expected := entries[i] expected := entries[i]
if op.Type != expected.Type { if op.Type != expected.Type {
t.Errorf("Operation %d: expected type %d, got %d", i, expected.Type, op.Type) t.Errorf("Operation %d: expected type %d, got %d", i, expected.Type, op.Type)
} }
if string(op.Key) != string(expected.Key) { if string(op.Key) != string(expected.Key) {
t.Errorf("Operation %d: expected key %q, got %q", i, string(expected.Key), string(op.Key)) t.Errorf("Operation %d: expected key %q, got %q", i, string(expected.Key), string(op.Key))
} }
if expected.Type != OpTypeDelete && string(op.Value) != string(expected.Value) { if expected.Type != OpTypeDelete && string(op.Value) != string(expected.Value) {
t.Errorf("Operation %d: expected value %q, got %q", i, string(expected.Value), string(op.Value)) t.Errorf("Operation %d: expected value %q, got %q", i, string(expected.Value), string(op.Value))
} }
} }
} }
} else { } else {
t.Errorf("Failed to decode batch: %v", err) t.Errorf("Failed to decode batch: %v", err)
} }