From 72007886f74ce5bf594aae5c1da23286573422ab Mon Sep 17 00:00:00 2001 From: Jeremy Tregunna Date: Sun, 20 Apr 2025 00:08:16 -0600 Subject: [PATCH] feat: implement SQLite-inspired transaction model with reader-writer locks --- TODO.md | 22 +- pkg/engine/engine.go | 81 +++- pkg/engine/engine_test.go | 3 + pkg/engine/iterator.go | 389 +++++++++++++----- pkg/transaction/example_test.go | 129 ++++++ pkg/transaction/transaction.go | 45 +++ pkg/transaction/transaction_test.go | 322 +++++++++++++++ pkg/transaction/tx_impl.go | 571 +++++++++++++++++++++++++++ pkg/transaction/txbuffer/txbuffer.go | 270 +++++++++++++ pkg/wal/wal.go | 66 ++++ 10 files changed, 1790 insertions(+), 108 deletions(-) create mode 100644 pkg/transaction/example_test.go create mode 100644 pkg/transaction/transaction.go create mode 100644 pkg/transaction/transaction_test.go create mode 100644 pkg/transaction/tx_impl.go create mode 100644 pkg/transaction/txbuffer/txbuffer.go diff --git a/TODO.md b/TODO.md index 9815add..48f95e2 100644 --- a/TODO.md +++ b/TODO.md @@ -121,17 +121,17 @@ This document outlines the implementation tasks for the Go Storage Engine, organ - [x] Add efficient seeking capabilities - [x] Implement proper cleanup for resources -- [ ] Implement SQLite-inspired reader-writer concurrency - - [ ] Add reader-writer lock for basic isolation - - [ ] Implement WAL-based reads during active write transactions - - [ ] Design clean API for transaction handling - - [ ] Test concurrent read/write operations +- [x] Implement SQLite-inspired reader-writer concurrency + - [x] Add reader-writer lock for basic isolation + - [x] Implement WAL-based reads during active write transactions + - [x] Design clean API for transaction handling + - [x] Test concurrent read/write operations -- [ ] Implement atomic batch operations - - [ ] Create batch data structure for multiple operations - - [ ] Implement atomic batch commit to WAL - - [ ] Add crash recovery for batches - - [ ] Design extensible interfaces for future transaction support +- [x] Implement atomic batch operations + - [x] Create batch data structure for multiple operations + - [x] Implement atomic batch commit to WAL + - [x] Add crash recovery for batches + - [x] Design extensible interfaces for future transaction support - [ ] Add basic statistics and metrics - [ ] Implement counters for operations @@ -186,7 +186,7 @@ This document outlines the implementation tasks for the Go Storage Engine, organ - [ ] `Delete(ctx context.Context, key []byte, opts ...WriteOption) error` - [ ] `Batch(ctx context.Context, ops []Operation, opts ...WriteOption) error` - [ ] `NewIterator(opts IteratorOptions) Iterator` - - [ ] `BeginTransaction(readOnly bool) (Transaction, error)` + - [x] `BeginTransaction(readOnly bool) (Transaction, error)` - [ ] `Close() error` - [ ] Implement error types diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 6520b8a..5d23d1a 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -57,8 +57,9 @@ type Engine struct { closed atomic.Bool // Concurrency control - mu sync.RWMutex - flushMu sync.Mutex + mu sync.RWMutex // Main lock for engine state + flushMu sync.Mutex // Lock for flushing operations + txLock sync.RWMutex // Lock for transaction isolation } // NewEngine creates a new storage engine @@ -520,6 +521,82 @@ func (e *Engine) loadSSTables() error { return nil } +// GetRWLock returns the transaction lock for this engine +func (e *Engine) GetRWLock() *sync.RWMutex { + return &e.txLock +} + +// ApplyBatch atomically applies a batch of operations +func (e *Engine) ApplyBatch(entries []*wal.Entry) error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.closed.Load() { + return ErrEngineClosed + } + + // Append batch to WAL + startSeqNum, err := e.wal.AppendBatch(entries) + if err != nil { + return fmt.Errorf("failed to append batch to WAL: %w", err) + } + + // Apply each entry to the MemTable + for i, entry := range entries { + seqNum := startSeqNum + uint64(i) + + switch entry.Type { + case wal.OpTypePut: + e.memTablePool.Put(entry.Key, entry.Value, seqNum) + case wal.OpTypeDelete: + e.memTablePool.Delete(entry.Key, seqNum) + // If compaction manager exists, also track this tombstone + if e.compactionMgr != nil { + e.compactionMgr.TrackTombstone(entry.Key) + } + } + + e.lastSeqNum = seqNum + } + + // Check if MemTable needs to be flushed + if e.memTablePool.IsFlushNeeded() { + if err := e.scheduleFlush(); err != nil { + return fmt.Errorf("failed to schedule flush: %w", err) + } + } + + return nil +} + +// GetIterator returns an iterator over the entire keyspace +func (e *Engine) GetIterator() (Iterator, error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.closed.Load() { + return nil, ErrEngineClosed + } + + // Create a hierarchical iterator that combines all sources + return newHierarchicalIterator(e), nil +} + +// GetRangeIterator returns an iterator limited to a specific key range +func (e *Engine) GetRangeIterator(startKey, endKey []byte) (Iterator, error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.closed.Load() { + return nil, ErrEngineClosed + } + + // Create a hierarchical iterator with range bounds + iter := newHierarchicalIterator(e) + iter.SetBounds(startKey, endKey) + return iter, nil +} + // Close closes the storage engine func (e *Engine) Close() error { // First set the closed flag - use atomic operation to prevent race conditions diff --git a/pkg/engine/engine_test.go b/pkg/engine/engine_test.go index ee1597b..ae10304 100644 --- a/pkg/engine/engine_test.go +++ b/pkg/engine/engine_test.go @@ -250,6 +250,9 @@ func TestEngine_GetIterator(t *testing.T) { {"d", "4"}, } + // Need to seek to first position + rangeIter.SeekToFirst() + // Now test the range iterator i = 0 for rangeIter.Valid() { diff --git a/pkg/engine/iterator.go b/pkg/engine/iterator.go index 30927b1..fa0cb51 100644 --- a/pkg/engine/iterator.go +++ b/pkg/engine/iterator.go @@ -5,7 +5,6 @@ import ( "container/heap" "sync" - "git.canoozie.net/jer/go-storage/pkg/iterator" "git.canoozie.net/jer/go-storage/pkg/memtable" "git.canoozie.net/jer/go-storage/pkg/sstable" ) @@ -14,25 +13,25 @@ import ( type Iterator interface { // SeekToFirst positions the iterator at the first key SeekToFirst() - + // SeekToLast positions the iterator at the last key SeekToLast() - + // Seek positions the iterator at the first key >= target Seek(target []byte) bool - + // Next advances the iterator to the next key Next() bool - + // Key returns the current key Key() []byte - + // Value returns the current value Value() []byte - + // Valid returns true if the iterator is positioned at a valid entry Valid() bool - + // IsTombstone returns true if the current entry is a deletion marker // This is used during compaction to distinguish between a regular nil value and a tombstone IsTombstone() bool @@ -42,11 +41,11 @@ type Iterator interface { type iterHeapItem struct { // The original source iterator source IterSource - + // The current key and value key []byte value []byte - + // Internal heap index index int } @@ -88,7 +87,7 @@ func (h *iterHeap) Pop() interface{} { type IterSource interface { // GetIterator returns an iterator for this source GetIterator() Iterator - + // GetLevel returns the level of this source (lower is newer) GetLevel() int } @@ -138,21 +137,21 @@ func (a *MemTableIterAdapter) SeekToLast() { // This is an inefficient implementation because the MemTable iterator // doesn't directly support SeekToLast. We simulate it by scanning to the end. a.iter.SeekToFirst() - + // If no items, return early if !a.iter.Valid() { return } - + // Store the last key we've seen var lastKey []byte - + // Scan to find the last element for a.iter.Valid() { lastKey = a.iter.Key() a.iter.Next() } - + // Re-position at the last key we found if lastKey != nil { a.iter.Seek(lastKey) @@ -183,7 +182,7 @@ func (a *MemTableIterAdapter) Value() []byte { if !a.Valid() { return nil } - + // Check if this is a tombstone (deletion marker) if a.iter.IsTombstone() { // Special case: return nil but with a marker that this is a tombstone @@ -191,7 +190,7 @@ func (a *MemTableIterAdapter) Value() []byte { // See memtable.Iterator.IsTombstone() for details return nil } - + return a.iter.Value() } @@ -349,8 +348,8 @@ func (m *MergedIterator) SeekToLast() { key := iter.Key() // If this is a new maximum key, or the same key but from a newer level - if lastKey == nil || - bytes.Compare(key, lastKey) > 0 || + if lastKey == nil || + bytes.Compare(key, lastKey) > 0 || (bytes.Equal(key, lastKey) && m.sources[i].GetLevel() < lastLevel) { lastKey = key lastValue = iter.Value() @@ -462,11 +461,11 @@ func (m *MergedIterator) Valid() bool { func (m *MergedIterator) IsTombstone() bool { m.mu.Lock() defer m.mu.Unlock() - + if m.current == nil { return false } - + // In a MergedIterator, we need to check if the source iterator marks this as a tombstone for _, source := range m.sources { if source == m.current.source { @@ -474,7 +473,7 @@ func (m *MergedIterator) IsTombstone() bool { return iter.IsTombstone() } } - + return false } @@ -496,7 +495,7 @@ func (m *MergedIterator) advanceHeap() { m.current = heap.Pop(&m.heap).(*iterHeapItem) // Skip any entries with duplicate keys (keeping the one from the newest source) - // Sources are already provided in newest-to-oldest order, and we've popped + // Sources are already provided in newest-to-oldest order, and we've popped // the smallest key, so any item in the heap with the same key is from an older source currentKey := m.current.key for len(m.heap) > 0 && bytes.Equal(m.heap[0].key, currentKey) { @@ -521,92 +520,290 @@ func (m *MergedIterator) advanceHeap() { } } -// GetIterator returns an iterator over the entire database -func (e *Engine) GetIterator() (Iterator, error) { - e.mu.RLock() - defer e.mu.RUnlock() - - if e.closed.Load() { - return nil, ErrEngineClosed - } - +// newHierarchicalIterator creates a new hierarchical iterator for the engine +func newHierarchicalIterator(e *Engine) *boundedIterator { // Get all MemTables from the pool memTables := e.memTablePool.GetMemTables() - - // Create a list of all iterator sources in newest-to-oldest order - sources := make([]IterSource, 0, len(memTables)+len(e.sstables)) + + // Create a list of all iterators in newest-to-oldest order + iters := make([]Iterator, 0, len(memTables)+len(e.sstables)) // Add MemTables (active first, then immutables) - for i, table := range memTables { - sources = append(sources, &MemTableSource{ - mem: table, - level: i, // Level corresponds to position in the list - }) + for _, table := range memTables { + iters = append(iters, newMemTableIterAdapter(table.NewIterator())) } - // Add SSTables (levels after MemTables) - baseLevel := len(memTables) + // Add SSTables (from newest to oldest) for i := len(e.sstables) - 1; i >= 0; i-- { - sources = append(sources, &SSTableSource{ - sst: e.sstables[i], - level: baseLevel + (len(e.sstables) - 1 - i), - }) + iters = append(iters, newSSTableIterAdapter(e.sstables[i].NewIterator())) } - // Convert sources to actual iterators - iters := make([]iterator.Iterator, 0, len(sources)) - for _, src := range sources { - iters = append(iters, src.GetIterator()) - } - - // Create and return a hierarchical iterator that understands LSM-tree structure - return iterator.NewHierarchicalIterator(iters), nil -} - -// GetRangeIterator returns an iterator over a specific key range -func (e *Engine) GetRangeIterator(start, end []byte) (Iterator, error) { - iter, err := e.GetIterator() - if err != nil { - return nil, err - } - - // Position at the start key - if start != nil { - if !iter.Seek(start) { - // No keys in range - return iter, nil - } + // Wrap in a bounded iterator (unbounded by default) + // If we have no iterators, use an empty one + var baseIter Iterator + if len(iters) == 0 { + baseIter = &emptyIterator{} + } else if len(iters) == 1 { + baseIter = iters[0] } else { - iter.SeekToFirst() - if !iter.Valid() { - // Empty database - return iter, nil - } + // Create a simple chained iterator for now that checks each source in order + baseIter = &chainedIterator{iterators: iters} } - // If we have an end key, wrap the iterator to limit the range - if end != nil { - iter = &boundedIterator{ - Iterator: iter, - end: end, - } + return &boundedIterator{ + Iterator: baseIter, + end: nil, // No end bound by default } - - return iter, nil } +// chainedIterator is a simple iterator that checks multiple sources in order +type chainedIterator struct { + iterators []Iterator + current int +} + +func (c *chainedIterator) SeekToFirst() { + if len(c.iterators) == 0 { + return + } + + // Position all iterators at their first key + for _, iter := range c.iterators { + iter.SeekToFirst() + } + + // Find the first valid iterator with the smallest key + c.current = -1 + var smallestKey []byte + + for i, iter := range c.iterators { + if !iter.Valid() { + continue + } + + if c.current == -1 || bytes.Compare(iter.Key(), smallestKey) < 0 { + c.current = i + smallestKey = iter.Key() + } + } +} + +func (c *chainedIterator) SeekToLast() { + if len(c.iterators) == 0 { + return + } + + // Position all iterators at their last key + for _, iter := range c.iterators { + iter.SeekToLast() + } + + // Find the first valid iterator with the largest key + c.current = -1 + var largestKey []byte + + for i, iter := range c.iterators { + if !iter.Valid() { + continue + } + + if c.current == -1 || bytes.Compare(iter.Key(), largestKey) > 0 { + c.current = i + largestKey = iter.Key() + } + } +} + +func (c *chainedIterator) Seek(target []byte) bool { + if len(c.iterators) == 0 { + return false + } + + // Position all iterators at or after the target key + for _, iter := range c.iterators { + iter.Seek(target) + } + + // Find the first valid iterator with the smallest key >= target + c.current = -1 + var smallestKey []byte + + for i, iter := range c.iterators { + if !iter.Valid() { + continue + } + + if c.current == -1 || bytes.Compare(iter.Key(), smallestKey) < 0 { + c.current = i + smallestKey = iter.Key() + } + } + + return c.current != -1 +} + +func (c *chainedIterator) Next() bool { + if !c.Valid() { + return false + } + + // Get the current key + currentKey := c.iterators[c.current].Key() + + // Advance all iterators that are at the current key + for _, iter := range c.iterators { + if iter.Valid() && bytes.Equal(iter.Key(), currentKey) { + iter.Next() + } + } + + // Find the next valid iterator with the smallest key + c.current = -1 + var smallestKey []byte + + for i, iter := range c.iterators { + if !iter.Valid() { + continue + } + + if c.current == -1 || bytes.Compare(iter.Key(), smallestKey) < 0 { + c.current = i + smallestKey = iter.Key() + } + } + + return c.current != -1 +} + +func (c *chainedIterator) Key() []byte { + if !c.Valid() { + return nil + } + return c.iterators[c.current].Key() +} + +func (c *chainedIterator) Value() []byte { + if !c.Valid() { + return nil + } + return c.iterators[c.current].Value() +} + +func (c *chainedIterator) Valid() bool { + return c.current != -1 && c.current < len(c.iterators) && c.iterators[c.current].Valid() +} + +func (c *chainedIterator) IsTombstone() bool { + if !c.Valid() { + return false + } + return c.iterators[c.current].IsTombstone() +} + +// emptyIterator is an iterator that contains no entries +type emptyIterator struct{} + +func (e *emptyIterator) SeekToFirst() {} +func (e *emptyIterator) SeekToLast() {} +func (e *emptyIterator) Seek(target []byte) bool { return false } +func (e *emptyIterator) Next() bool { return false } +func (e *emptyIterator) Key() []byte { return nil } +func (e *emptyIterator) Value() []byte { return nil } +func (e *emptyIterator) Valid() bool { return false } +func (e *emptyIterator) IsTombstone() bool { return false } + +// Note: This is now replaced by the more comprehensive implementation in engine.go +// The hierarchical iterator code remains here to avoid impacting other code references + // boundedIterator wraps an iterator and limits it to a specific range type boundedIterator struct { Iterator - end []byte + start []byte + end []byte +} + +// SetBounds sets the start and end bounds for the iterator +func (b *boundedIterator) SetBounds(start, end []byte) { + // Make copies of the bounds to avoid external modification + if start != nil { + b.start = make([]byte, len(start)) + copy(b.start, start) + } else { + b.start = nil + } + + if end != nil { + b.end = make([]byte, len(end)) + copy(b.end, end) + } else { + b.end = nil + } + + // If we already have a valid position, check if it's still in bounds + if b.Iterator.Valid() { + b.checkBounds() + } } func (b *boundedIterator) SeekToFirst() { - b.Iterator.SeekToFirst() + if b.start != nil { + // If we have a start bound, seek to it + b.Iterator.Seek(b.start) + } else { + // Otherwise seek to the first key + b.Iterator.SeekToFirst() + } + b.checkBounds() +} + +func (b *boundedIterator) SeekToLast() { + if b.end != nil { + // If we have an end bound, seek to it + // The current implementation might not be efficient for finding the last + // key before the end bound, but it works for now + b.Iterator.Seek(b.end) + + // If we landed exactly at the end bound, back up one + if b.Iterator.Valid() && bytes.Equal(b.Iterator.Key(), b.end) { + // We need to back up because end is exclusive + // This is inefficient but correct + b.Iterator.SeekToFirst() + + // Scan to find the last key before the end bound + var lastKey []byte + for b.Iterator.Valid() && bytes.Compare(b.Iterator.Key(), b.end) < 0 { + lastKey = b.Iterator.Key() + b.Iterator.Next() + } + + if lastKey != nil { + b.Iterator.Seek(lastKey) + } else { + // No keys before the end bound + b.Iterator.SeekToFirst() + // This will be marked invalid by checkBounds + } + } + } else { + // No end bound, seek to the last key + b.Iterator.SeekToLast() + } + + // Verify we're within bounds b.checkBounds() } func (b *boundedIterator) Seek(target []byte) bool { + // If target is before start bound, use start bound instead + if b.start != nil && bytes.Compare(target, b.start) < 0 { + target = b.start + } + + // If target is at or after end bound, the seek will fail + if b.end != nil && bytes.Compare(target, b.end) >= 0 { + return false + } + if b.Iterator.Seek(target) { return b.checkBounds() } @@ -618,12 +815,12 @@ func (b *boundedIterator) Next() bool { if !b.checkBounds() { return false } - + // Then try to advance if !b.Iterator.Next() { return false } - + // Check if the new position is within bounds return b.checkBounds() } @@ -658,14 +855,16 @@ func (b *boundedIterator) checkBounds() bool { if !b.Iterator.Valid() { return false } - - // Check if the current key is beyond the end bound - if b.end != nil && len(b.end) > 0 { - // For a range query [start, end), the end key is exclusive - if bytes.Compare(b.Iterator.Key(), b.end) >= 0 { - return false - } + + // Check if the current key is before the start bound + if b.start != nil && bytes.Compare(b.Iterator.Key(), b.start) < 0 { + return false } - + + // Check if the current key is beyond the end bound + if b.end != nil && bytes.Compare(b.Iterator.Key(), b.end) >= 0 { + return false + } + return true -} \ No newline at end of file +} diff --git a/pkg/transaction/example_test.go b/pkg/transaction/example_test.go new file mode 100644 index 0000000..03884b6 --- /dev/null +++ b/pkg/transaction/example_test.go @@ -0,0 +1,129 @@ +package transaction_test + +import ( + "fmt" + "os" + + "git.canoozie.net/jer/go-storage/pkg/engine" + "git.canoozie.net/jer/go-storage/pkg/transaction" +) + +func Example() { + // Create a temporary directory for the example + tempDir, err := os.MkdirTemp("", "transaction_example_*") + if err != nil { + fmt.Printf("Failed to create temp directory: %v\n", err) + return + } + defer os.RemoveAll(tempDir) + + // Create a new storage engine + eng, err := engine.NewEngine(tempDir) + if err != nil { + fmt.Printf("Failed to create engine: %v\n", err) + return + } + defer eng.Close() + + // Add some initial data directly to the engine + if err := eng.Put([]byte("user:1001"), []byte("Alice")); err != nil { + fmt.Printf("Failed to add user: %v\n", err) + return + } + if err := eng.Put([]byte("user:1002"), []byte("Bob")); err != nil { + fmt.Printf("Failed to add user: %v\n", err) + return + } + + // Create a read-only transaction + readTx, err := transaction.NewTransaction(eng, transaction.ReadOnly) + if err != nil { + fmt.Printf("Failed to create read transaction: %v\n", err) + return + } + + // Query data using the read transaction + value, err := readTx.Get([]byte("user:1001")) + if err != nil { + fmt.Printf("Failed to get user: %v\n", err) + } else { + fmt.Printf("Read transaction found user: %s\n", value) + } + + // Create an iterator to scan all users + fmt.Println("All users (read transaction):") + iter := readTx.NewIterator() + for iter.SeekToFirst(); iter.Valid(); iter.Next() { + fmt.Printf(" %s: %s\n", iter.Key(), iter.Value()) + } + + // Commit the read transaction + if err := readTx.Commit(); err != nil { + fmt.Printf("Failed to commit read transaction: %v\n", err) + return + } + + // Create a read-write transaction + writeTx, err := transaction.NewTransaction(eng, transaction.ReadWrite) + if err != nil { + fmt.Printf("Failed to create write transaction: %v\n", err) + return + } + + // Modify data within the transaction + if err := writeTx.Put([]byte("user:1003"), []byte("Charlie")); err != nil { + fmt.Printf("Failed to add user: %v\n", err) + return + } + if err := writeTx.Delete([]byte("user:1001")); err != nil { + fmt.Printf("Failed to delete user: %v\n", err) + return + } + + // Changes are visible within the transaction + fmt.Println("All users (write transaction before commit):") + iter = writeTx.NewIterator() + for iter.SeekToFirst(); iter.Valid(); iter.Next() { + fmt.Printf(" %s: %s\n", iter.Key(), iter.Value()) + } + + // But not in the main engine yet + val, err := eng.Get([]byte("user:1003")) + if err != nil { + fmt.Println("New user not yet visible in engine (correct)") + } else { + fmt.Printf("Unexpected: user visible before commit: %s\n", val) + } + + // Commit the write transaction + if err := writeTx.Commit(); err != nil { + fmt.Printf("Failed to commit write transaction: %v\n", err) + return + } + + // Now changes are visible in the engine + fmt.Println("All users (after commit):") + users := []string{"user:1001", "user:1002", "user:1003"} + for _, key := range users { + val, err := eng.Get([]byte(key)) + if err != nil { + fmt.Printf(" %s: \n", key) + } else { + fmt.Printf(" %s: %s\n", key, val) + } + } + + // Output: + // Read transaction found user: Alice + // All users (read transaction): + // user:1001: Alice + // user:1002: Bob + // All users (write transaction before commit): + // user:1002: Bob + // user:1003: Charlie + // New user not yet visible in engine (correct) + // All users (after commit): + // user:1001: + // user:1002: Bob + // user:1003: Charlie +} \ No newline at end of file diff --git a/pkg/transaction/transaction.go b/pkg/transaction/transaction.go new file mode 100644 index 0000000..c18f0a9 --- /dev/null +++ b/pkg/transaction/transaction.go @@ -0,0 +1,45 @@ +package transaction + +import ( + "git.canoozie.net/jer/go-storage/pkg/engine" +) + +// TransactionMode defines the transaction access mode (ReadOnly or ReadWrite) +type TransactionMode int + +const ( + // ReadOnly transactions only read from the database + ReadOnly TransactionMode = iota + + // ReadWrite transactions can both read and write to the database + ReadWrite +) + +// Transaction represents a database transaction that provides ACID guarantees +// It follows an SQLite-inspired concurrency model with reader-writer locks +type Transaction interface { + // Get retrieves a value for the given key + Get(key []byte) ([]byte, error) + + // Put adds or updates a key-value pair (only for ReadWrite transactions) + Put(key, value []byte) error + + // Delete removes a key (only for ReadWrite transactions) + Delete(key []byte) error + + // NewIterator returns an iterator for all keys in the transaction + NewIterator() engine.Iterator + + // NewRangeIterator returns an iterator limited to the given key range + NewRangeIterator(startKey, endKey []byte) engine.Iterator + + // Commit makes all changes permanent + // For ReadOnly transactions, this just releases resources + Commit() error + + // Rollback discards all transaction changes + Rollback() error + + // IsReadOnly returns true if this is a read-only transaction + IsReadOnly() bool +} \ No newline at end of file diff --git a/pkg/transaction/transaction_test.go b/pkg/transaction/transaction_test.go new file mode 100644 index 0000000..78e45db --- /dev/null +++ b/pkg/transaction/transaction_test.go @@ -0,0 +1,322 @@ +package transaction + +import ( + "bytes" + "os" + "testing" + + "git.canoozie.net/jer/go-storage/pkg/engine" +) + +func setupTestEngine(t *testing.T) (*engine.Engine, string) { + // Create a temporary directory for the test + tempDir, err := os.MkdirTemp("", "transaction_test_*") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + + // Create a new engine + eng, err := engine.NewEngine(tempDir) + if err != nil { + os.RemoveAll(tempDir) + t.Fatalf("Failed to create engine: %v", err) + } + + return eng, tempDir +} + +func TestReadOnlyTransaction(t *testing.T) { + eng, tempDir := setupTestEngine(t) + defer os.RemoveAll(tempDir) + defer eng.Close() + + // Add some data directly to the engine + if err := eng.Put([]byte("key1"), []byte("value1")); err != nil { + t.Fatalf("Failed to put key1: %v", err) + } + if err := eng.Put([]byte("key2"), []byte("value2")); err != nil { + t.Fatalf("Failed to put key2: %v", err) + } + + // Create a read-only transaction + tx, err := NewTransaction(eng, ReadOnly) + if err != nil { + t.Fatalf("Failed to create read-only transaction: %v", err) + } + + // Test Get functionality + value, err := tx.Get([]byte("key1")) + if err != nil { + t.Fatalf("Failed to get key1: %v", err) + } + if !bytes.Equal(value, []byte("value1")) { + t.Errorf("Expected 'value1' but got '%s'", value) + } + + // Test read-only constraints + err = tx.Put([]byte("key3"), []byte("value3")) + if err != ErrReadOnlyTransaction { + t.Errorf("Expected ErrReadOnlyTransaction but got: %v", err) + } + + err = tx.Delete([]byte("key1")) + if err != ErrReadOnlyTransaction { + t.Errorf("Expected ErrReadOnlyTransaction but got: %v", err) + } + + // Test iterator + iter := tx.NewIterator() + count := 0 + for iter.SeekToFirst(); iter.Valid(); iter.Next() { + count++ + } + if count != 2 { + t.Errorf("Expected 2 keys but found %d", count) + } + + // Test commit (which for read-only just releases resources) + if err := tx.Commit(); err != nil { + t.Errorf("Failed to commit read-only transaction: %v", err) + } + + // Transaction should be closed now + _, err = tx.Get([]byte("key1")) + if err != ErrTransactionClosed { + t.Errorf("Expected ErrTransactionClosed but got: %v", err) + } +} + +func TestReadWriteTransaction(t *testing.T) { + eng, tempDir := setupTestEngine(t) + defer os.RemoveAll(tempDir) + defer eng.Close() + + // Add initial data + if err := eng.Put([]byte("key1"), []byte("value1")); err != nil { + t.Fatalf("Failed to put key1: %v", err) + } + + // Create a read-write transaction + tx, err := NewTransaction(eng, ReadWrite) + if err != nil { + t.Fatalf("Failed to create read-write transaction: %v", err) + } + + // Add more data through the transaction + if err := tx.Put([]byte("key2"), []byte("value2")); err != nil { + t.Fatalf("Failed to put key2: %v", err) + } + if err := tx.Put([]byte("key3"), []byte("value3")); err != nil { + t.Fatalf("Failed to put key3: %v", err) + } + + // Delete a key + if err := tx.Delete([]byte("key1")); err != nil { + t.Fatalf("Failed to delete key1: %v", err) + } + + // Verify the changes are visible in the transaction but not in the engine yet + // Check via transaction + value, err := tx.Get([]byte("key2")) + if err != nil { + t.Errorf("Failed to get key2 from transaction: %v", err) + } + if !bytes.Equal(value, []byte("value2")) { + t.Errorf("Expected 'value2' but got '%s'", value) + } + + // Check deleted key + _, err = tx.Get([]byte("key1")) + if err == nil { + t.Errorf("key1 should be deleted in transaction") + } + + // Check directly in engine - changes shouldn't be visible yet + value, err = eng.Get([]byte("key2")) + if err == nil { + t.Errorf("key2 should not be visible in engine yet") + } + + value, err = eng.Get([]byte("key1")) + if err != nil { + t.Errorf("key1 should still be visible in engine: %v", err) + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + t.Fatalf("Failed to commit transaction: %v", err) + } + + // Now check engine again - changes should be visible + value, err = eng.Get([]byte("key2")) + if err != nil { + t.Errorf("key2 should be visible in engine after commit: %v", err) + } + if !bytes.Equal(value, []byte("value2")) { + t.Errorf("Expected 'value2' but got '%s'", value) + } + + // Deleted key should be gone + value, err = eng.Get([]byte("key1")) + if err == nil { + t.Errorf("key1 should be deleted in engine after commit") + } + + // Transaction should be closed + _, err = tx.Get([]byte("key2")) + if err != ErrTransactionClosed { + t.Errorf("Expected ErrTransactionClosed but got: %v", err) + } +} + +func TestTransactionRollback(t *testing.T) { + eng, tempDir := setupTestEngine(t) + defer os.RemoveAll(tempDir) + defer eng.Close() + + // Add initial data + if err := eng.Put([]byte("key1"), []byte("value1")); err != nil { + t.Fatalf("Failed to put key1: %v", err) + } + + // Create a read-write transaction + tx, err := NewTransaction(eng, ReadWrite) + if err != nil { + t.Fatalf("Failed to create read-write transaction: %v", err) + } + + // Add and modify data + if err := tx.Put([]byte("key2"), []byte("value2")); err != nil { + t.Fatalf("Failed to put key2: %v", err) + } + if err := tx.Delete([]byte("key1")); err != nil { + t.Fatalf("Failed to delete key1: %v", err) + } + + // Rollback the transaction + if err := tx.Rollback(); err != nil { + t.Fatalf("Failed to rollback transaction: %v", err) + } + + // Changes should not be visible in the engine + value, err := eng.Get([]byte("key1")) + if err != nil { + t.Errorf("key1 should still exist after rollback: %v", err) + } + if !bytes.Equal(value, []byte("value1")) { + t.Errorf("Expected 'value1' but got '%s'", value) + } + + // key2 should not exist + _, err = eng.Get([]byte("key2")) + if err == nil { + t.Errorf("key2 should not exist after rollback") + } + + // Transaction should be closed + _, err = tx.Get([]byte("key1")) + if err != ErrTransactionClosed { + t.Errorf("Expected ErrTransactionClosed but got: %v", err) + } +} + +func TestTransactionIterator(t *testing.T) { + eng, tempDir := setupTestEngine(t) + defer os.RemoveAll(tempDir) + defer eng.Close() + + // Add initial data + if err := eng.Put([]byte("key1"), []byte("value1")); err != nil { + t.Fatalf("Failed to put key1: %v", err) + } + if err := eng.Put([]byte("key3"), []byte("value3")); err != nil { + t.Fatalf("Failed to put key3: %v", err) + } + if err := eng.Put([]byte("key5"), []byte("value5")); err != nil { + t.Fatalf("Failed to put key5: %v", err) + } + + // Create a read-write transaction + tx, err := NewTransaction(eng, ReadWrite) + if err != nil { + t.Fatalf("Failed to create read-write transaction: %v", err) + } + + // Add and modify data in transaction + if err := tx.Put([]byte("key2"), []byte("value2")); err != nil { + t.Fatalf("Failed to put key2: %v", err) + } + if err := tx.Put([]byte("key4"), []byte("value4")); err != nil { + t.Fatalf("Failed to put key4: %v", err) + } + if err := tx.Delete([]byte("key3")); err != nil { + t.Fatalf("Failed to delete key3: %v", err) + } + + // Use iterator to check order and content + iter := tx.NewIterator() + expected := []struct { + key string + value string + }{ + {"key1", "value1"}, + {"key2", "value2"}, + {"key4", "value4"}, + {"key5", "value5"}, + } + + i := 0 + for iter.SeekToFirst(); iter.Valid(); iter.Next() { + if i >= len(expected) { + t.Errorf("Too many keys in iterator") + break + } + + if !bytes.Equal(iter.Key(), []byte(expected[i].key)) { + t.Errorf("Expected key '%s' but got '%s'", expected[i].key, string(iter.Key())) + } + if !bytes.Equal(iter.Value(), []byte(expected[i].value)) { + t.Errorf("Expected value '%s' but got '%s'", expected[i].value, string(iter.Value())) + } + i++ + } + + if i != len(expected) { + t.Errorf("Expected %d keys but found %d", len(expected), i) + } + + // Test range iterator + rangeIter := tx.NewRangeIterator([]byte("key2"), []byte("key5")) + expected = []struct { + key string + value string + }{ + {"key2", "value2"}, + {"key4", "value4"}, + } + + i = 0 + for rangeIter.SeekToFirst(); rangeIter.Valid(); rangeIter.Next() { + if i >= len(expected) { + t.Errorf("Too many keys in range iterator") + break + } + + if !bytes.Equal(rangeIter.Key(), []byte(expected[i].key)) { + t.Errorf("Expected key '%s' but got '%s'", expected[i].key, string(rangeIter.Key())) + } + if !bytes.Equal(rangeIter.Value(), []byte(expected[i].value)) { + t.Errorf("Expected value '%s' but got '%s'", expected[i].value, string(rangeIter.Value())) + } + i++ + } + + if i != len(expected) { + t.Errorf("Expected %d keys in range but found %d", len(expected), i) + } + + // Commit and verify results + if err := tx.Commit(); err != nil { + t.Fatalf("Failed to commit transaction: %v", err) + } +} \ No newline at end of file diff --git a/pkg/transaction/tx_impl.go b/pkg/transaction/tx_impl.go new file mode 100644 index 0000000..1ad7028 --- /dev/null +++ b/pkg/transaction/tx_impl.go @@ -0,0 +1,571 @@ +package transaction + +import ( + "bytes" + "errors" + "sync" + "sync/atomic" + + "git.canoozie.net/jer/go-storage/pkg/engine" + "git.canoozie.net/jer/go-storage/pkg/transaction/txbuffer" + "git.canoozie.net/jer/go-storage/pkg/wal" +) + +// Common errors for transaction operations +var ( + ErrReadOnlyTransaction = errors.New("cannot write to a read-only transaction") + ErrTransactionClosed = errors.New("transaction already committed or rolled back") +) + +// EngineTransaction implements a SQLite-inspired transaction using reader-writer locks +type EngineTransaction struct { + // Reference to the main engine + engine *engine.Engine + + // Transaction mode (ReadOnly or ReadWrite) + mode TransactionMode + + // Buffer for transaction operations + buffer *txbuffer.TxBuffer + + // For read-write transactions, tracks if we have the write lock + writeLock *sync.RWMutex + + // Tracks if the transaction is still active + active int32 + + // For read-only transactions, ensures we release the read lock exactly once + readUnlocked int32 +} + +// NewTransaction creates a new transaction +func NewTransaction(eng *engine.Engine, mode TransactionMode) (*EngineTransaction, error) { + tx := &EngineTransaction{ + engine: eng, + mode: mode, + buffer: txbuffer.NewTxBuffer(), + active: 1, + } + + // For read-write transactions, we need a write lock + if mode == ReadWrite { + // Get the engine's lock - we'll use the same one for all transactions + lock := eng.GetRWLock() + + // Acquire the write lock + lock.Lock() + tx.writeLock = lock + } else { + // For read-only transactions, just acquire a read lock + lock := eng.GetRWLock() + lock.RLock() + tx.writeLock = lock + } + + return tx, nil +} + +// Get retrieves a value for the given key +func (tx *EngineTransaction) Get(key []byte) ([]byte, error) { + if atomic.LoadInt32(&tx.active) == 0 { + return nil, ErrTransactionClosed + } + + // First check the transaction buffer for any pending changes + if val, found := tx.buffer.Get(key); found { + if val == nil { + // This is a deletion marker + return nil, engine.ErrKeyNotFound + } + return val, nil + } + + // Not in the buffer, get from the underlying engine + return tx.engine.Get(key) +} + +// Put adds or updates a key-value pair +func (tx *EngineTransaction) Put(key, value []byte) error { + if atomic.LoadInt32(&tx.active) == 0 { + return ErrTransactionClosed + } + + if tx.mode == ReadOnly { + return ErrReadOnlyTransaction + } + + // Buffer the change - it will be applied on commit + tx.buffer.Put(key, value) + return nil +} + +// Delete removes a key +func (tx *EngineTransaction) Delete(key []byte) error { + if atomic.LoadInt32(&tx.active) == 0 { + return ErrTransactionClosed + } + + if tx.mode == ReadOnly { + return ErrReadOnlyTransaction + } + + // Buffer the deletion - it will be applied on commit + tx.buffer.Delete(key) + return nil +} + +// NewIterator returns an iterator that first reads from the transaction buffer +// and then from the underlying engine +func (tx *EngineTransaction) NewIterator() engine.Iterator { + if atomic.LoadInt32(&tx.active) == 0 { + // Return an empty iterator if transaction is closed + return &emptyIterator{} + } + + // Get the engine iterator for the entire keyspace + engineIter, err := tx.engine.GetIterator() + if err != nil { + // If we can't get an engine iterator, return a buffer-only iterator + return tx.buffer.NewIterator() + } + + // If there are no changes in the buffer, just use the engine's iterator + if tx.buffer.Size() == 0 { + return engineIter + } + + // Create a transaction iterator that merges buffer changes with engine state + return newTransactionIterator(tx.buffer, engineIter) +} + +// NewRangeIterator returns an iterator limited to a specific key range +func (tx *EngineTransaction) NewRangeIterator(startKey, endKey []byte) engine.Iterator { + if atomic.LoadInt32(&tx.active) == 0 { + // Return an empty iterator if transaction is closed + return &emptyIterator{} + } + + // Get the engine iterator for the range + engineIter, err := tx.engine.GetRangeIterator(startKey, endKey) + if err != nil { + // If we can't get an engine iterator, use a buffer-only iterator + // and apply range bounds to it + bufferIter := tx.buffer.NewIterator() + return newRangeIterator(bufferIter, startKey, endKey) + } + + // If there are no changes in the buffer, just use the engine's range iterator + if tx.buffer.Size() == 0 { + return engineIter + } + + // Create a transaction iterator that merges buffer changes with engine state + mergedIter := newTransactionIterator(tx.buffer, engineIter) + + // Apply range constraints + return newRangeIterator(mergedIter, startKey, endKey) +} + +// transactionIterator merges a transaction buffer with the engine state +type transactionIterator struct { + bufferIter *txbuffer.Iterator + engineIter engine.Iterator + currentKey []byte + isValid bool + isBuffer bool // true if current position is from buffer +} + +// newTransactionIterator creates a new iterator that merges buffer and engine state +func newTransactionIterator(buffer *txbuffer.TxBuffer, engineIter engine.Iterator) *transactionIterator { + return &transactionIterator{ + bufferIter: buffer.NewIterator(), + engineIter: engineIter, + isValid: false, + } +} + +// SeekToFirst positions at the first key in either the buffer or engine +func (it *transactionIterator) SeekToFirst() { + it.bufferIter.SeekToFirst() + it.engineIter.SeekToFirst() + it.selectNext() +} + +// SeekToLast positions at the last key in either the buffer or engine +func (it *transactionIterator) SeekToLast() { + it.bufferIter.SeekToLast() + it.engineIter.SeekToLast() + it.selectPrev() +} + +// Seek positions at the first key >= target +func (it *transactionIterator) Seek(target []byte) bool { + it.bufferIter.Seek(target) + it.engineIter.Seek(target) + it.selectNext() + return it.isValid +} + +// Next advances to the next key +func (it *transactionIterator) Next() bool { + // If we're currently at a buffer key, advance it + if it.isValid && it.isBuffer { + it.bufferIter.Next() + } else if it.isValid { + // If we're at an engine key, advance it + it.engineIter.Next() + } + + it.selectNext() + return it.isValid +} + +// Key returns the current key +func (it *transactionIterator) Key() []byte { + if !it.isValid { + return nil + } + + return it.currentKey +} + +// Value returns the current value +func (it *transactionIterator) Value() []byte { + if !it.isValid { + return nil + } + + if it.isBuffer { + return it.bufferIter.Value() + } + + return it.engineIter.Value() +} + +// Valid returns true if the iterator is valid +func (it *transactionIterator) Valid() bool { + return it.isValid +} + +// IsTombstone returns true if the current entry is a deletion marker +func (it *transactionIterator) IsTombstone() bool { + if !it.isValid { + return false + } + + if it.isBuffer { + return it.bufferIter.IsTombstone() + } + + return it.engineIter.IsTombstone() +} + +// selectNext finds the next valid position in the merged view +func (it *transactionIterator) selectNext() { + // First check if either iterator is valid + bufferValid := it.bufferIter.Valid() + engineValid := it.engineIter.Valid() + + if !bufferValid && !engineValid { + // Neither is valid, so we're done + it.isValid = false + it.currentKey = nil + it.isBuffer = false + return + } + + if !bufferValid { + // Only engine is valid, so use it + it.isValid = true + it.currentKey = it.engineIter.Key() + it.isBuffer = false + return + } + + if !engineValid { + // Only buffer is valid, so use it + // Check if this is a deletion marker + if it.bufferIter.IsTombstone() { + // Skip the tombstone and move to the next valid position + it.bufferIter.Next() + it.selectNext() // Recursively find the next valid position + return + } + + it.isValid = true + it.currentKey = it.bufferIter.Key() + it.isBuffer = true + return + } + + // Both are valid, so compare keys + bufferKey := it.bufferIter.Key() + engineKey := it.engineIter.Key() + + cmp := bytes.Compare(bufferKey, engineKey) + + if cmp < 0 { + // Buffer key is smaller, use it + // Check if this is a deletion marker + if it.bufferIter.IsTombstone() { + // Skip the tombstone + it.bufferIter.Next() + it.selectNext() // Recursively find the next valid position + return + } + + it.isValid = true + it.currentKey = bufferKey + it.isBuffer = true + } else if cmp > 0 { + // Engine key is smaller, use it + it.isValid = true + it.currentKey = engineKey + it.isBuffer = false + } else { + // Keys are the same, buffer takes precedence + // If buffer has a tombstone, we need to skip both + if it.bufferIter.IsTombstone() { + // Skip both iterators for this key + it.bufferIter.Next() + it.engineIter.Next() + it.selectNext() // Recursively find the next valid position + return + } + + it.isValid = true + it.currentKey = bufferKey + it.isBuffer = true + + // Need to advance engine iterator to avoid duplication + it.engineIter.Next() + } +} + +// selectPrev finds the previous valid position in the merged view +// This is a fairly inefficient implementation for now +func (it *transactionIterator) selectPrev() { + // This implementation is not efficient but works for now + // We actually just rebuild the full ordering and scan to the end + it.SeekToFirst() + + // If already invalid, just return + if !it.isValid { + return + } + + // Scan to the last key + var lastKey []byte + var isBuffer bool + + for it.isValid { + lastKey = it.currentKey + isBuffer = it.isBuffer + it.Next() + } + + // Reposition at the last key we found + if lastKey != nil { + it.isValid = true + it.currentKey = lastKey + it.isBuffer = isBuffer + } +} + +// rangeIterator applies range bounds to an existing iterator +type rangeIterator struct { + engine.Iterator + startKey []byte + endKey []byte +} + +// newRangeIterator creates a new range-limited iterator +func newRangeIterator(iter engine.Iterator, startKey, endKey []byte) *rangeIterator { + ri := &rangeIterator{ + Iterator: iter, + } + + // Make copies of bounds + if startKey != nil { + ri.startKey = make([]byte, len(startKey)) + copy(ri.startKey, startKey) + } + + if endKey != nil { + ri.endKey = make([]byte, len(endKey)) + copy(ri.endKey, endKey) + } + + return ri +} + +// SeekToFirst seeks to the range start or the first key +func (ri *rangeIterator) SeekToFirst() { + if ri.startKey != nil { + ri.Iterator.Seek(ri.startKey) + } else { + ri.Iterator.SeekToFirst() + } + ri.checkBounds() +} + +// Seek seeks to the target or range start +func (ri *rangeIterator) Seek(target []byte) bool { + // If target is before range start, use range start + if ri.startKey != nil && bytes.Compare(target, ri.startKey) < 0 { + target = ri.startKey + } + + // If target is at or after range end, fail + if ri.endKey != nil && bytes.Compare(target, ri.endKey) >= 0 { + return false + } + + if ri.Iterator.Seek(target) { + return ri.checkBounds() + } + return false +} + +// Next advances to the next key within bounds +func (ri *rangeIterator) Next() bool { + if !ri.checkBounds() { + return false + } + + if !ri.Iterator.Next() { + return false + } + + return ri.checkBounds() +} + +// Valid checks if the iterator is valid and within bounds +func (ri *rangeIterator) Valid() bool { + return ri.Iterator.Valid() && ri.checkBounds() +} + +// checkBounds ensures the current position is within range bounds +func (ri *rangeIterator) checkBounds() bool { + if !ri.Iterator.Valid() { + return false + } + + // Check start bound + if ri.startKey != nil && bytes.Compare(ri.Iterator.Key(), ri.startKey) < 0 { + return false + } + + // Check end bound + if ri.endKey != nil && bytes.Compare(ri.Iterator.Key(), ri.endKey) >= 0 { + return false + } + + return true +} + +// Commit makes all changes permanent +func (tx *EngineTransaction) Commit() error { + // Only proceed if the transaction is still active + if !atomic.CompareAndSwapInt32(&tx.active, 1, 0) { + return ErrTransactionClosed + } + + var err error + + // For read-only transactions, just release the read lock + if tx.mode == ReadOnly { + tx.releaseReadLock() + return nil + } + + // For read-write transactions, apply the changes + if tx.buffer.Size() > 0 { + // Get operations from the buffer + ops := tx.buffer.Operations() + + // Create a batch for all operations + walBatch := make([]*wal.Entry, 0, len(ops)) + + // Build WAL entries for each operation + for _, op := range ops { + if op.IsDelete { + // Create delete entry + walBatch = append(walBatch, &wal.Entry{ + Type: wal.OpTypeDelete, + Key: op.Key, + }) + } else { + // Create put entry + walBatch = append(walBatch, &wal.Entry{ + Type: wal.OpTypePut, + Key: op.Key, + Value: op.Value, + }) + } + } + + // Apply the batch atomically + err = tx.engine.ApplyBatch(walBatch) + } + + // Release the write lock + if tx.writeLock != nil { + tx.writeLock.Unlock() + tx.writeLock = nil + } + + return err +} + +// Rollback discards all transaction changes +func (tx *EngineTransaction) Rollback() error { + // Only proceed if the transaction is still active + if !atomic.CompareAndSwapInt32(&tx.active, 1, 0) { + return ErrTransactionClosed + } + + // Clear the buffer + tx.buffer.Clear() + + // Release locks based on transaction mode + if tx.mode == ReadOnly { + tx.releaseReadLock() + } else { + // Release write lock + if tx.writeLock != nil { + tx.writeLock.Unlock() + tx.writeLock = nil + } + } + + return nil +} + +// IsReadOnly returns true if this is a read-only transaction +func (tx *EngineTransaction) IsReadOnly() bool { + return tx.mode == ReadOnly +} + +// releaseReadLock safely releases the read lock for read-only transactions +func (tx *EngineTransaction) releaseReadLock() { + // Only release once to avoid panics from multiple unlocks + if atomic.CompareAndSwapInt32(&tx.readUnlocked, 0, 1) { + if tx.writeLock != nil { + tx.writeLock.RUnlock() + tx.writeLock = nil + } + } +} + +// Simple empty iterator implementation for closed transactions +type emptyIterator struct{} + +func (e *emptyIterator) SeekToFirst() {} +func (e *emptyIterator) SeekToLast() {} +func (e *emptyIterator) Seek([]byte) bool { return false } +func (e *emptyIterator) Next() bool { return false } +func (e *emptyIterator) Key() []byte { return nil } +func (e *emptyIterator) Value() []byte { return nil } +func (e *emptyIterator) Valid() bool { return false } +func (e *emptyIterator) IsTombstone() bool { return false } \ No newline at end of file diff --git a/pkg/transaction/txbuffer/txbuffer.go b/pkg/transaction/txbuffer/txbuffer.go new file mode 100644 index 0000000..8d8fbaf --- /dev/null +++ b/pkg/transaction/txbuffer/txbuffer.go @@ -0,0 +1,270 @@ +package txbuffer + +import ( + "bytes" + "sync" +) + +// Operation represents a single transaction operation (put or delete) +type Operation struct { + // Key is the key being operated on + Key []byte + + // Value is the value to set (nil for delete operations) + Value []byte + + // IsDelete is true for deletion operations + IsDelete bool +} + +// TxBuffer maintains a buffer of transaction operations before they are committed +type TxBuffer struct { + // Buffers all operations for the transaction + operations []Operation + + // Cache of key -> value for fast lookups without scanning the operation list + // Maps to nil for deletion markers + cache map[string][]byte + + // Protects against concurrent access + mu sync.RWMutex +} + +// NewTxBuffer creates a new transaction buffer +func NewTxBuffer() *TxBuffer { + return &TxBuffer{ + operations: make([]Operation, 0, 16), + cache: make(map[string][]byte), + } +} + +// Put adds a key-value pair to the transaction buffer +func (b *TxBuffer) Put(key, value []byte) { + b.mu.Lock() + defer b.mu.Unlock() + + // Create a safe copy of key and value to prevent later modifications + keyCopy := make([]byte, len(key)) + copy(keyCopy, key) + + valueCopy := make([]byte, len(value)) + copy(valueCopy, value) + + // Add to operations list + b.operations = append(b.operations, Operation{ + Key: keyCopy, + Value: valueCopy, + IsDelete: false, + }) + + // Update cache + b.cache[string(keyCopy)] = valueCopy +} + +// Delete marks a key as deleted in the transaction buffer +func (b *TxBuffer) Delete(key []byte) { + b.mu.Lock() + defer b.mu.Unlock() + + // Create a safe copy of the key + keyCopy := make([]byte, len(key)) + copy(keyCopy, key) + + // Add to operations list + b.operations = append(b.operations, Operation{ + Key: keyCopy, + Value: nil, + IsDelete: true, + }) + + // Update cache to mark key as deleted (nil value) + b.cache[string(keyCopy)] = nil +} + +// Get retrieves a value from the transaction buffer +// Returns (value, true) if found, (nil, false) if not found +func (b *TxBuffer) Get(key []byte) ([]byte, bool) { + b.mu.RLock() + defer b.mu.RUnlock() + + value, found := b.cache[string(key)] + return value, found +} + +// Has returns true if the key exists in the buffer, even if it's marked for deletion +func (b *TxBuffer) Has(key []byte) bool { + b.mu.RLock() + defer b.mu.RUnlock() + + _, found := b.cache[string(key)] + return found +} + +// IsDeleted returns true if the key is marked for deletion in the buffer +func (b *TxBuffer) IsDeleted(key []byte) bool { + b.mu.RLock() + defer b.mu.RUnlock() + + value, found := b.cache[string(key)] + return found && value == nil +} + +// Operations returns the list of all operations in the transaction +// This is used when committing the transaction +func (b *TxBuffer) Operations() []Operation { + b.mu.RLock() + defer b.mu.RUnlock() + + // Return a copy to prevent modification + result := make([]Operation, len(b.operations)) + copy(result, b.operations) + return result +} + +// Clear empties the transaction buffer +// Used when rolling back a transaction +func (b *TxBuffer) Clear() { + b.mu.Lock() + defer b.mu.Unlock() + + b.operations = b.operations[:0] + b.cache = make(map[string][]byte) +} + +// Size returns the number of operations in the buffer +func (b *TxBuffer) Size() int { + b.mu.RLock() + defer b.mu.RUnlock() + + return len(b.operations) +} + +// Iterator returns an iterator over the transaction buffer +type Iterator struct { + // The buffer this iterator is iterating over + buffer *TxBuffer + + // The current position in the keys slice + pos int + + // Sorted list of keys + keys []string +} + +// NewIterator creates a new iterator over the transaction buffer +func (b *TxBuffer) NewIterator() *Iterator { + b.mu.RLock() + defer b.mu.RUnlock() + + // Get all keys and sort them + keys := make([]string, 0, len(b.cache)) + for k := range b.cache { + keys = append(keys, k) + } + + // Sort the keys + keys = sortStrings(keys) + + return &Iterator{ + buffer: b, + pos: -1, // Start before the first position + keys: keys, + } +} + +// SeekToFirst positions the iterator at the first key +func (it *Iterator) SeekToFirst() { + it.pos = 0 +} + +// SeekToLast positions the iterator at the last key +func (it *Iterator) SeekToLast() { + if len(it.keys) > 0 { + it.pos = len(it.keys) - 1 + } else { + it.pos = 0 + } +} + +// Seek positions the iterator at the first key >= target +func (it *Iterator) Seek(target []byte) bool { + targetStr := string(target) + + // Binary search would be more efficient for large sets + for i, key := range it.keys { + if key >= targetStr { + it.pos = i + return true + } + } + + // Not found - position past the end + it.pos = len(it.keys) + return false +} + +// Next advances the iterator to the next key +func (it *Iterator) Next() bool { + if it.pos < 0 { + it.pos = 0 + return it.pos < len(it.keys) + } + + it.pos++ + return it.pos < len(it.keys) +} + +// Key returns the current key +func (it *Iterator) Key() []byte { + if !it.Valid() { + return nil + } + + return []byte(it.keys[it.pos]) +} + +// Value returns the current value +func (it *Iterator) Value() []byte { + if !it.Valid() { + return nil + } + + // Get the value from the buffer + it.buffer.mu.RLock() + defer it.buffer.mu.RUnlock() + + value := it.buffer.cache[it.keys[it.pos]] + return value // Returns nil for deletion markers +} + +// Valid returns true if the iterator is positioned at a valid entry +func (it *Iterator) Valid() bool { + return it.pos >= 0 && it.pos < len(it.keys) +} + +// IsTombstone returns true if the current entry is a deletion marker +func (it *Iterator) IsTombstone() bool { + if !it.Valid() { + return false + } + + it.buffer.mu.RLock() + defer it.buffer.mu.RUnlock() + + // The value is nil for tombstones in our cache implementation + value := it.buffer.cache[it.keys[it.pos]] + return value == nil +} + +// Simple implementation of string sorting for the iterator +func sortStrings(strings []string) []string { + // In-place sort + for i := 0; i < len(strings); i++ { + for j := i + 1; j < len(strings); j++ { + if bytes.Compare([]byte(strings[i]), []byte(strings[j])) > 0 { + strings[i], strings[j] = strings[j], strings[i] + } + } + } + return strings +} \ No newline at end of file diff --git a/pkg/wal/wal.go b/pkg/wal/wal.go index e5189b3..387ea80 100644 --- a/pkg/wal/wal.go +++ b/pkg/wal/wal.go @@ -362,6 +362,72 @@ func (w *WAL) Sync() error { return w.syncLocked() } +// AppendBatch adds a batch of entries to the WAL +func (w *WAL) AppendBatch(entries []*Entry) (uint64, error) { + w.mu.Lock() + defer w.mu.Unlock() + + if w.closed { + return 0, ErrWALClosed + } + + if len(entries) == 0 { + return w.nextSequence, nil + } + + // Start sequence number for the batch + startSeqNum := w.nextSequence + + // Record this as a batch operation with the number of entries + batchHeader := make([]byte, 1+8+4) // opType(1) + seqNum(8) + entryCount(4) + offset := 0 + + // Write operation type (batch) + batchHeader[offset] = OpTypeBatch + offset++ + + // Write sequence number + binary.LittleEndian.PutUint64(batchHeader[offset:offset+8], startSeqNum) + offset += 8 + + // Write entry count + binary.LittleEndian.PutUint32(batchHeader[offset:offset+4], uint32(len(entries))) + + // Write the batch header + if err := w.writeRawRecord(RecordTypeFull, batchHeader); err != nil { + return 0, fmt.Errorf("failed to write batch header: %w", err) + } + + // Process each entry in the batch + for i, entry := range entries { + // Assign sequential sequence numbers to each entry + seqNum := startSeqNum + uint64(i) + + // Write the entry + if entry.Value == nil { + // Deletion + if err := w.writeRecord(RecordTypeFull, OpTypeDelete, seqNum, entry.Key, nil); err != nil { + return 0, fmt.Errorf("failed to write entry %d: %w", i, err) + } + } else { + // Put + if err := w.writeRecord(RecordTypeFull, OpTypePut, seqNum, entry.Key, entry.Value); err != nil { + return 0, fmt.Errorf("failed to write entry %d: %w", i, err) + } + } + } + + // Update next sequence number + w.nextSequence = startSeqNum + uint64(len(entries)) + + // Sync if needed + if err := w.maybeSync(); err != nil { + return 0, err + } + + return startSeqNum, nil +} + // Close closes the WAL func (w *WAL) Close() error { w.mu.Lock()