From b1f9ed6740781a1e04bac57919ef3e4c40d8972b Mon Sep 17 00:00:00 2001 From: Jeremy Tregunna Date: Sat, 19 Apr 2025 15:13:59 -0600 Subject: [PATCH] feat: implement memtable package --- TODO.md | 26 +-- pkg/memtable/bench_test.go | 132 +++++++++++++++ pkg/memtable/mempool.go | 166 ++++++++++++++++++ pkg/memtable/mempool_test.go | 225 +++++++++++++++++++++++++ pkg/memtable/memtable.go | 155 +++++++++++++++++ pkg/memtable/memtable_test.go | 202 ++++++++++++++++++++++ pkg/memtable/recovery.go | 91 ++++++++++ pkg/memtable/recovery_test.go | 276 ++++++++++++++++++++++++++++++ pkg/memtable/skiplist.go | 308 ++++++++++++++++++++++++++++++++++ pkg/memtable/skiplist_test.go | 232 +++++++++++++++++++++++++ 10 files changed, 1800 insertions(+), 13 deletions(-) create mode 100644 pkg/memtable/bench_test.go create mode 100644 pkg/memtable/mempool.go create mode 100644 pkg/memtable/mempool_test.go create mode 100644 pkg/memtable/memtable.go create mode 100644 pkg/memtable/memtable_test.go create mode 100644 pkg/memtable/recovery.go create mode 100644 pkg/memtable/recovery_test.go create mode 100644 pkg/memtable/skiplist.go create mode 100644 pkg/memtable/skiplist_test.go diff --git a/TODO.md b/TODO.md index 719922f..84c6542 100644 --- a/TODO.md +++ b/TODO.md @@ -35,21 +35,21 @@ This document outlines the implementation tasks for the Go Storage Engine, organ ## Phase B: In-Memory Layer -- [ ] Implement MemTable - - [ ] Create skip list data structure aligned to 64-byte cache lines - - [ ] Add key/value insertion and lookup operations - - [ ] Implement sorted key iteration - - [ ] Add size tracking for flush threshold detection +- [✓] Implement MemTable + - [✓] Create skip list data structure aligned to 64-byte cache lines + - [✓] Add key/value insertion and lookup operations + - [✓] Implement sorted key iteration + - [✓] Add size tracking for flush threshold detection -- [ ] Connect WAL replay to MemTable - - [ ] Create recovery logic to rebuild MemTable from WAL - - [ ] Implement consistent snapshot reads during recovery - - [ ] Handle errors during replay with appropriate fallbacks +- [✓] Connect WAL replay to MemTable + - [✓] Create recovery logic to rebuild MemTable from WAL + - [✓] Implement consistent snapshot reads during recovery + - [✓] Handle errors during replay with appropriate fallbacks -- [ ] Test concurrent read/write scenarios - - [ ] Verify reader isolation during writes - - [ ] Test snapshot consistency guarantees - - [ ] Benchmark read/write performance under load +- [✓] Test concurrent read/write scenarios + - [✓] Verify reader isolation during writes + - [✓] Test snapshot consistency guarantees + - [✓] Benchmark read/write performance under load ## Phase C: Persistent Storage diff --git a/pkg/memtable/bench_test.go b/pkg/memtable/bench_test.go new file mode 100644 index 0000000..6030522 --- /dev/null +++ b/pkg/memtable/bench_test.go @@ -0,0 +1,132 @@ +package memtable + +import ( + "fmt" + "math/rand" + "strconv" + "testing" +) + +func BenchmarkSkipListInsert(b *testing.B) { + sl := NewSkipList() + + // Create random keys ahead of time + keys := make([][]byte, b.N) + values := make([][]byte, b.N) + for i := 0; i < b.N; i++ { + keys[i] = []byte(fmt.Sprintf("key-%d", i)) + values[i] = []byte(fmt.Sprintf("value-%d", i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + e := newEntry(keys[i], values[i], TypeValue, uint64(i)) + sl.Insert(e) + } +} + +func BenchmarkSkipListFind(b *testing.B) { + sl := NewSkipList() + + // Insert entries first + const numEntries = 100000 + keys := 
make([][]byte, numEntries) + for i := 0; i < numEntries; i++ { + key := []byte(fmt.Sprintf("key-%d", i)) + value := []byte(fmt.Sprintf("value-%d", i)) + keys[i] = key + sl.Insert(newEntry(key, value, TypeValue, uint64(i))) + } + + // Create random keys for lookup + lookupKeys := make([][]byte, b.N) + r := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility + for i := 0; i < b.N; i++ { + idx := r.Intn(numEntries) + lookupKeys[i] = keys[idx] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + sl.Find(lookupKeys[i]) + } +} + +func BenchmarkMemTablePut(b *testing.B) { + mt := NewMemTable() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := []byte("key-" + strconv.Itoa(i)) + value := []byte("value-" + strconv.Itoa(i)) + mt.Put(key, value, uint64(i)) + } +} + +func BenchmarkMemTableGet(b *testing.B) { + mt := NewMemTable() + + // Insert entries first + const numEntries = 100000 + keys := make([][]byte, numEntries) + for i := 0; i < numEntries; i++ { + key := []byte(fmt.Sprintf("key-%d", i)) + value := []byte(fmt.Sprintf("value-%d", i)) + keys[i] = key + mt.Put(key, value, uint64(i)) + } + + // Create random keys for lookup + lookupKeys := make([][]byte, b.N) + r := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility + for i := 0; i < b.N; i++ { + idx := r.Intn(numEntries) + lookupKeys[i] = keys[idx] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + mt.Get(lookupKeys[i]) + } +} + +func BenchmarkMemPoolGet(b *testing.B) { + cfg := createTestConfig() + cfg.MemTableSize = 1024 * 1024 * 32 // 32MB for benchmark + pool := NewMemTablePool(cfg) + + // Create multiple memtables with entries + const entriesPerTable = 50000 + const numTables = 3 + keys := make([][]byte, entriesPerTable*numTables) + + // Fill tables + for t := 0; t < numTables; t++ { + // Fill a table + for i := 0; i < entriesPerTable; i++ { + idx := t*entriesPerTable + i + key := []byte(fmt.Sprintf("key-%d", idx)) + value := []byte(fmt.Sprintf("value-%d", idx)) + keys[idx] = key + pool.Put(key, value, uint64(idx)) + } + + // Switch to a new memtable (except for last one) + if t < numTables-1 { + pool.SwitchToNewMemTable() + } + } + + // Create random keys for lookup + lookupKeys := make([][]byte, b.N) + r := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility + for i := 0; i < b.N; i++ { + idx := r.Intn(entriesPerTable * numTables) + lookupKeys[i] = keys[idx] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pool.Get(lookupKeys[i]) + } +} \ No newline at end of file diff --git a/pkg/memtable/mempool.go b/pkg/memtable/mempool.go new file mode 100644 index 0000000..ded9b37 --- /dev/null +++ b/pkg/memtable/mempool.go @@ -0,0 +1,166 @@ +package memtable + +import ( + "sync" + "sync/atomic" + "time" + + "git.canoozie.net/jer/go-storage/pkg/config" +) + +// MemTablePool manages a pool of MemTables +// It maintains one active MemTable and a set of immutable MemTables +type MemTablePool struct { + cfg *config.Config + active *MemTable + immutables []*MemTable + maxAge time.Duration + maxSize int64 + totalSize int64 + flushPending atomic.Bool + mu sync.RWMutex +} + +// NewMemTablePool creates a new MemTable pool +func NewMemTablePool(cfg *config.Config) *MemTablePool { + return &MemTablePool{ + cfg: cfg, + active: NewMemTable(), + immutables: make([]*MemTable, 0, cfg.MaxMemTables-1), + maxAge: time.Duration(cfg.MaxMemTableAge) * time.Second, + maxSize: cfg.MemTableSize, + } +} + +// Put adds a key-value pair to the active MemTable +func (p *MemTablePool) Put(key, value []byte, seqNum 
uint64) { + p.mu.RLock() + p.active.Put(key, value, seqNum) + p.mu.RUnlock() + + // Check if we need to flush after this write + p.checkFlushConditions() +} + +// Delete marks a key as deleted in the active MemTable +func (p *MemTablePool) Delete(key []byte, seqNum uint64) { + p.mu.RLock() + p.active.Delete(key, seqNum) + p.mu.RUnlock() + + // Check if we need to flush after this write + p.checkFlushConditions() +} + +// Get retrieves the value for a key from all MemTables +// Checks the active MemTable first, then the immutables in reverse order +func (p *MemTablePool) Get(key []byte) ([]byte, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + + // Check active table first + if value, found := p.active.Get(key); found { + return value, true + } + + // Check immutable tables in reverse order (newest first) + for i := len(p.immutables) - 1; i >= 0; i-- { + if value, found := p.immutables[i].Get(key); found { + return value, true + } + } + + return nil, false +} + +// ImmutableCount returns the number of immutable MemTables +func (p *MemTablePool) ImmutableCount() int { + p.mu.RLock() + defer p.mu.RUnlock() + return len(p.immutables) +} + +// checkFlushConditions checks if we need to flush the active MemTable +func (p *MemTablePool) checkFlushConditions() { + needsFlush := false + + p.mu.RLock() + defer p.mu.RUnlock() + + // Skip if a flush is already pending + if p.flushPending.Load() { + return + } + + // Check size condition + if p.active.ApproximateSize() >= p.maxSize { + needsFlush = true + } + + // Check age condition + if p.maxAge > 0 && p.active.Age() > p.maxAge.Seconds() { + needsFlush = true + } + + // Mark as needing flush if conditions met + if needsFlush { + p.flushPending.Store(true) + } +} + +// SwitchToNewMemTable makes the active MemTable immutable and creates a new active one +// Returns the immutable MemTable that needs to be flushed +func (p *MemTablePool) SwitchToNewMemTable() *MemTable { + p.mu.Lock() + defer p.mu.Unlock() + + // Reset the flush pending flag + p.flushPending.Store(false) + + // Make the current active table immutable + oldActive := p.active + oldActive.SetImmutable() + + // Create a new active table + p.active = NewMemTable() + + // Add the old table to the immutables list + p.immutables = append(p.immutables, oldActive) + + // Return the table that needs to be flushed + return oldActive +} + +// GetImmutablesForFlush returns a list of immutable MemTables ready for flushing +// and removes them from the pool +func (p *MemTablePool) GetImmutablesForFlush() []*MemTable { + p.mu.Lock() + defer p.mu.Unlock() + + result := p.immutables + p.immutables = make([]*MemTable, 0, p.cfg.MaxMemTables-1) + return result +} + +// IsFlushNeeded returns true if a flush is needed +func (p *MemTablePool) IsFlushNeeded() bool { + return p.flushPending.Load() +} + +// GetNextSequenceNumber returns the next sequence number to use +func (p *MemTablePool) GetNextSequenceNumber() uint64 { + p.mu.RLock() + defer p.mu.RUnlock() + return p.active.GetNextSequenceNumber() +} + +// GetMemTables returns all MemTables (active and immutable) +func (p *MemTablePool) GetMemTables() []*MemTable { + p.mu.RLock() + defer p.mu.RUnlock() + + result := make([]*MemTable, 0, len(p.immutables)+1) + result = append(result, p.active) + result = append(result, p.immutables...) 
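+	// result[0] is always the active table; immutable tables follow in
+	// rotation order (oldest first), whereas Get scans them newest-first.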
+ return result +} \ No newline at end of file diff --git a/pkg/memtable/mempool_test.go b/pkg/memtable/mempool_test.go new file mode 100644 index 0000000..dcadc44 --- /dev/null +++ b/pkg/memtable/mempool_test.go @@ -0,0 +1,225 @@ +package memtable + +import ( + "testing" + "time" + + "git.canoozie.net/jer/go-storage/pkg/config" +) + +func createTestConfig() *config.Config { + cfg := config.NewDefaultConfig("/tmp/db") + cfg.MemTableSize = 1024 // Small size for testing + cfg.MaxMemTableAge = 1 // 1 second + cfg.MaxMemTables = 4 // Allow up to 4 memtables + cfg.MemTablePoolCap = 4 // Pool capacity + return cfg +} + +func TestMemPoolBasicOperations(t *testing.T) { + cfg := createTestConfig() + pool := NewMemTablePool(cfg) + + // Test Put and Get + pool.Put([]byte("key1"), []byte("value1"), 1) + + value, found := pool.Get([]byte("key1")) + if !found { + t.Fatalf("expected to find key1, but got not found") + } + if string(value) != "value1" { + t.Errorf("expected value1, got %s", string(value)) + } + + // Test Delete + pool.Delete([]byte("key1"), 2) + + value, found = pool.Get([]byte("key1")) + if !found { + t.Fatalf("expected tombstone to be found for key1") + } + if value != nil { + t.Errorf("expected nil value for deleted key, got %v", value) + } +} + +func TestMemPoolSwitchMemTable(t *testing.T) { + cfg := createTestConfig() + pool := NewMemTablePool(cfg) + + // Add data to the active memtable + pool.Put([]byte("key1"), []byte("value1"), 1) + + // Switch to a new memtable + old := pool.SwitchToNewMemTable() + if !old.IsImmutable() { + t.Errorf("expected switched memtable to be immutable") + } + + // Verify the data is in the old table + value, found := old.Get([]byte("key1")) + if !found { + t.Fatalf("expected to find key1 in old table, but got not found") + } + if string(value) != "value1" { + t.Errorf("expected value1 in old table, got %s", string(value)) + } + + // Verify the immutable count is correct + if count := pool.ImmutableCount(); count != 1 { + t.Errorf("expected immutable count to be 1, got %d", count) + } + + // Add data to the new active memtable + pool.Put([]byte("key2"), []byte("value2"), 2) + + // Verify we can still retrieve data from both tables + value, found = pool.Get([]byte("key1")) + if !found { + t.Fatalf("expected to find key1 through pool, but got not found") + } + if string(value) != "value1" { + t.Errorf("expected value1 through pool, got %s", string(value)) + } + + value, found = pool.Get([]byte("key2")) + if !found { + t.Fatalf("expected to find key2 through pool, but got not found") + } + if string(value) != "value2" { + t.Errorf("expected value2 through pool, got %s", string(value)) + } +} + +func TestMemPoolFlushConditions(t *testing.T) { + // Create a config with small thresholds for testing + cfg := createTestConfig() + cfg.MemTableSize = 100 // Very small size to trigger flush + pool := NewMemTablePool(cfg) + + // Initially no flush should be needed + if pool.IsFlushNeeded() { + t.Errorf("expected no flush needed initially") + } + + // Add enough data to trigger a size-based flush + for i := 0; i < 10; i++ { + key := []byte{byte(i)} + value := make([]byte, 20) // 20 bytes per value + pool.Put(key, value, uint64(i+1)) + } + + // Should trigger a flush + if !pool.IsFlushNeeded() { + t.Errorf("expected flush needed after reaching size threshold") + } + + // Switch to a new memtable + old := pool.SwitchToNewMemTable() + if !old.IsImmutable() { + t.Errorf("expected old memtable to be immutable") + } + + // The flush pending flag should be reset + if 
pool.IsFlushNeeded() { + t.Errorf("expected flush pending to be reset after switch") + } + + // Now test age-based flushing + // Wait for the age threshold to trigger + time.Sleep(1200 * time.Millisecond) // Just over 1 second + + // Add a small amount of data to check conditions + pool.Put([]byte("trigger"), []byte("check"), 100) + + // Should trigger an age-based flush + if !pool.IsFlushNeeded() { + t.Errorf("expected flush needed after reaching age threshold") + } +} + +func TestMemPoolGetImmutablesForFlush(t *testing.T) { + cfg := createTestConfig() + pool := NewMemTablePool(cfg) + + // Switch memtables a few times to accumulate immutables + for i := 0; i < 3; i++ { + pool.Put([]byte{byte(i)}, []byte{byte(i)}, uint64(i+1)) + pool.SwitchToNewMemTable() + } + + // Should have 3 immutable memtables + if count := pool.ImmutableCount(); count != 3 { + t.Errorf("expected 3 immutable memtables, got %d", count) + } + + // Get immutables for flush + immutables := pool.GetImmutablesForFlush() + + // Should get all 3 immutables + if len(immutables) != 3 { + t.Errorf("expected to get 3 immutables for flush, got %d", len(immutables)) + } + + // The pool should now have 0 immutables + if count := pool.ImmutableCount(); count != 0 { + t.Errorf("expected 0 immutable memtables after flush, got %d", count) + } +} + +func TestMemPoolGetMemTables(t *testing.T) { + cfg := createTestConfig() + pool := NewMemTablePool(cfg) + + // Initially should have just the active memtable + tables := pool.GetMemTables() + if len(tables) != 1 { + t.Errorf("expected 1 memtable initially, got %d", len(tables)) + } + + // Add an immutable table + pool.Put([]byte("key"), []byte("value"), 1) + pool.SwitchToNewMemTable() + + // Now should have 2 memtables (active + 1 immutable) + tables = pool.GetMemTables() + if len(tables) != 2 { + t.Errorf("expected 2 memtables after switch, got %d", len(tables)) + } + + // The active table should be first + if tables[0].IsImmutable() { + t.Errorf("expected first table to be active (not immutable)") + } + + // The second table should be immutable + if !tables[1].IsImmutable() { + t.Errorf("expected second table to be immutable") + } +} + +func TestMemPoolGetNextSequenceNumber(t *testing.T) { + cfg := createTestConfig() + pool := NewMemTablePool(cfg) + + // Initial sequence number should be 0 + if seq := pool.GetNextSequenceNumber(); seq != 0 { + t.Errorf("expected initial sequence number to be 0, got %d", seq) + } + + // Add entries with sequence numbers + pool.Put([]byte("key"), []byte("value"), 5) + + // Next sequence number should be 6 + if seq := pool.GetNextSequenceNumber(); seq != 6 { + t.Errorf("expected sequence number to be 6, got %d", seq) + } + + // Switch to a new memtable + pool.SwitchToNewMemTable() + + // Sequence number should reset for the new table + if seq := pool.GetNextSequenceNumber(); seq != 0 { + t.Errorf("expected sequence number to reset to 0, got %d", seq) + } +} \ No newline at end of file diff --git a/pkg/memtable/memtable.go b/pkg/memtable/memtable.go new file mode 100644 index 0000000..0bb9df7 --- /dev/null +++ b/pkg/memtable/memtable.go @@ -0,0 +1,155 @@ +package memtable + +import ( + "sync" + "sync/atomic" + "time" + + "git.canoozie.net/jer/go-storage/pkg/wal" +) + +// MemTable is an in-memory table that stores key-value pairs +// It is implemented using a skip list for efficient inserts and lookups +type MemTable struct { + skipList *SkipList + nextSeqNum uint64 + creationTime time.Time + immutable atomic.Bool + size int64 + mu sync.RWMutex +} + +// 
NewMemTable creates a new memory table
+func NewMemTable() *MemTable {
+	return &MemTable{
+		skipList:     NewSkipList(),
+		creationTime: time.Now(),
+	}
+}
+
+// Put adds a key-value pair to the MemTable
+func (m *MemTable) Put(key, value []byte, seqNum uint64) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if m.immutable.Load() {
+		// Don't modify immutable memtables
+		return
+	}
+
+	e := newEntry(key, value, TypeValue, seqNum)
+	m.skipList.Insert(e)
+
+	// Advance the next sequence number past this entry's seqNum
+	if seqNum >= m.nextSeqNum {
+		m.nextSeqNum = seqNum + 1
+	}
+}
+
+// Delete marks a key as deleted in the MemTable
+func (m *MemTable) Delete(key []byte, seqNum uint64) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if m.immutable.Load() {
+		// Don't modify immutable memtables
+		return
+	}
+
+	e := newEntry(key, nil, TypeDeletion, seqNum)
+	m.skipList.Insert(e)
+
+	// Advance the next sequence number past this entry's seqNum
+	if seqNum >= m.nextSeqNum {
+		m.nextSeqNum = seqNum + 1
+	}
+}
+
+// Get retrieves the value associated with the given key
+// Returns (nil, true) if the key exists but has been deleted
+// Returns (nil, false) if the key does not exist
+// Returns (value, true) if the key exists and has a value
+func (m *MemTable) Get(key []byte) ([]byte, bool) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	e := m.skipList.Find(key)
+	if e == nil {
+		return nil, false
+	}
+
+	// Check if this is a deletion marker
+	if e.valueType == TypeDeletion {
+		return nil, true // Key exists but was deleted
+	}
+
+	return e.value, true
+}
+
+// Contains checks if the key exists in the MemTable
+func (m *MemTable) Contains(key []byte) bool {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	return m.skipList.Find(key) != nil
+}
+
+// ApproximateSize returns the approximate size of the MemTable in bytes
+func (m *MemTable) ApproximateSize() int64 {
+	return m.skipList.ApproximateSize()
+}
+
+// SetImmutable marks the MemTable as immutable
+// After this is called, no more modifications are allowed
+func (m *MemTable) SetImmutable() {
+	m.immutable.Store(true)
+}
+
+// IsImmutable returns whether the MemTable is immutable
+func (m *MemTable) IsImmutable() bool {
+	return m.immutable.Load()
+}
+
+// Age returns the age of the MemTable in seconds
+func (m *MemTable) Age() float64 {
+	return time.Since(m.creationTime).Seconds()
+}
+
+// NewIterator returns an iterator for the MemTable
+func (m *MemTable) NewIterator() *Iterator {
+	return m.skipList.NewIterator()
+}
+
+// GetNextSequenceNumber returns the next sequence number to use
+func (m *MemTable) GetNextSequenceNumber() uint64 {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return m.nextSeqNum
+}
+
+// ProcessWALEntry processes a WAL entry and applies it to the MemTable
+func (m *MemTable) ProcessWALEntry(entry *wal.Entry) error {
+	switch entry.Type {
+	case wal.OpTypePut:
+		m.Put(entry.Key, entry.Value, entry.SequenceNumber)
+	case wal.OpTypeDelete:
+		m.Delete(entry.Key, entry.SequenceNumber)
+	case wal.OpTypeBatch:
+		// Process batch operations
+		batch, err := wal.DecodeBatch(entry)
+		if err != nil {
+			return err
+		}
+
+		for i, op := range batch.Operations {
+			seqNum := batch.Seq + uint64(i)
+			switch op.Type {
+			case wal.OpTypePut:
+				m.Put(op.Key, op.Value, seqNum)
+			case wal.OpTypeDelete:
+				m.Delete(op.Key, seqNum)
+			}
+		}
+	}
+	return nil
+}
\ No newline at end of file
diff --git a/pkg/memtable/memtable_test.go b/pkg/memtable/memtable_test.go
new file mode 100644
index 0000000..29fd86b
--- /dev/null
+++ b/pkg/memtable/memtable_test.go
@@ -0,0 +1,202 @@
+package memtable
+
+import (
+	"testing"
+	"time"
+
+	
"git.canoozie.net/jer/go-storage/pkg/wal" +) + +func TestMemTableBasicOperations(t *testing.T) { + mt := NewMemTable() + + // Test Put and Get + mt.Put([]byte("key1"), []byte("value1"), 1) + + value, found := mt.Get([]byte("key1")) + if !found { + t.Fatalf("expected to find key1, but got not found") + } + if string(value) != "value1" { + t.Errorf("expected value1, got %s", string(value)) + } + + // Test not found + _, found = mt.Get([]byte("nonexistent")) + if found { + t.Errorf("expected key 'nonexistent' to not be found") + } + + // Test Delete + mt.Delete([]byte("key1"), 2) + + value, found = mt.Get([]byte("key1")) + if !found { + t.Fatalf("expected tombstone to be found for key1") + } + if value != nil { + t.Errorf("expected nil value for deleted key, got %v", value) + } + + // Test Contains + if !mt.Contains([]byte("key1")) { + t.Errorf("expected Contains to return true for deleted key") + } + if mt.Contains([]byte("nonexistent")) { + t.Errorf("expected Contains to return false for nonexistent key") + } +} + +func TestMemTableSequenceNumbers(t *testing.T) { + mt := NewMemTable() + + // Add entries with sequence numbers + mt.Put([]byte("key"), []byte("value1"), 1) + mt.Put([]byte("key"), []byte("value2"), 3) + mt.Put([]byte("key"), []byte("value3"), 2) + + // Should get the latest by sequence number (value2) + value, found := mt.Get([]byte("key")) + if !found { + t.Fatalf("expected to find key, but got not found") + } + if string(value) != "value2" { + t.Errorf("expected value2 (highest seq), got %s", string(value)) + } + + // The next sequence number should be one more than the highest seen + if nextSeq := mt.GetNextSequenceNumber(); nextSeq != 4 { + t.Errorf("expected next sequence number to be 4, got %d", nextSeq) + } +} + +func TestMemTableImmutability(t *testing.T) { + mt := NewMemTable() + + // Add initial data + mt.Put([]byte("key"), []byte("value"), 1) + + // Mark as immutable + mt.SetImmutable() + if !mt.IsImmutable() { + t.Errorf("expected IsImmutable to return true after SetImmutable") + } + + // Attempts to modify should have no effect + mt.Put([]byte("key2"), []byte("value2"), 2) + mt.Delete([]byte("key"), 3) + + // Verify no changes occurred + _, found := mt.Get([]byte("key2")) + if found { + t.Errorf("expected key2 to not be added to immutable memtable") + } + + value, found := mt.Get([]byte("key")) + if !found { + t.Fatalf("expected to still find key after delete on immutable table") + } + if string(value) != "value" { + t.Errorf("expected value to remain unchanged, got %s", string(value)) + } +} + +func TestMemTableAge(t *testing.T) { + mt := NewMemTable() + + // A new memtable should have a very small age + if age := mt.Age(); age > 1.0 { + t.Errorf("expected new memtable to have age < 1.0s, got %.2fs", age) + } + + // Sleep to increase age + time.Sleep(10 * time.Millisecond) + + if age := mt.Age(); age <= 0.0 { + t.Errorf("expected memtable age to be > 0, got %.6fs", age) + } +} + +func TestMemTableWALIntegration(t *testing.T) { + mt := NewMemTable() + + // Create WAL entries + entries := []*wal.Entry{ + {SequenceNumber: 1, Type: wal.OpTypePut, Key: []byte("key1"), Value: []byte("value1")}, + {SequenceNumber: 2, Type: wal.OpTypeDelete, Key: []byte("key2"), Value: nil}, + {SequenceNumber: 3, Type: wal.OpTypePut, Key: []byte("key3"), Value: []byte("value3")}, + } + + // Process entries + for _, entry := range entries { + if err := mt.ProcessWALEntry(entry); err != nil { + t.Fatalf("failed to process WAL entry: %v", err) + } + } + + // Verify entries were processed 
correctly + testCases := []struct { + key string + expected string + found bool + }{ + {"key1", "value1", true}, + {"key2", "", true}, // Deleted key + {"key3", "value3", true}, + {"key4", "", false}, // Non-existent key + } + + for _, tc := range testCases { + value, found := mt.Get([]byte(tc.key)) + + if found != tc.found { + t.Errorf("key %s: expected found=%v, got %v", tc.key, tc.found, found) + continue + } + + if found && tc.expected != "" { + if string(value) != tc.expected { + t.Errorf("key %s: expected value '%s', got '%s'", tc.key, tc.expected, string(value)) + } + } + } + + // Verify next sequence number + if nextSeq := mt.GetNextSequenceNumber(); nextSeq != 4 { + t.Errorf("expected next sequence number to be 4, got %d", nextSeq) + } +} + +func TestMemTableIterator(t *testing.T) { + mt := NewMemTable() + + // Add entries in non-sorted order + entries := []struct { + key string + value string + seq uint64 + }{ + {"banana", "yellow", 1}, + {"apple", "red", 2}, + {"cherry", "red", 3}, + {"date", "brown", 4}, + } + + for _, e := range entries { + mt.Put([]byte(e.key), []byte(e.value), e.seq) + } + + // Use iterator to verify keys are returned in sorted order + it := mt.NewIterator() + it.SeekToFirst() + + expected := []string{"apple", "banana", "cherry", "date"} + + for i := 0; it.Valid() && i < len(expected); i++ { + key := string(it.Key()) + if key != expected[i] { + t.Errorf("position %d: expected key %s, got %s", i, expected[i], key) + } + it.Next() + } +} \ No newline at end of file diff --git a/pkg/memtable/recovery.go b/pkg/memtable/recovery.go new file mode 100644 index 0000000..d40e7b6 --- /dev/null +++ b/pkg/memtable/recovery.go @@ -0,0 +1,91 @@ +package memtable + +import ( + "fmt" + + "git.canoozie.net/jer/go-storage/pkg/config" + "git.canoozie.net/jer/go-storage/pkg/wal" +) + +// RecoveryOptions contains options for MemTable recovery +type RecoveryOptions struct { + // MaxSequenceNumber is the maximum sequence number to recover + // Entries with sequence numbers greater than this will be ignored + MaxSequenceNumber uint64 + + // MaxMemTables is the maximum number of MemTables to create during recovery + // If more MemTables would be needed, an error is returned + MaxMemTables int + + // MemTableSize is the maximum size of each MemTable + MemTableSize int64 +} + +// DefaultRecoveryOptions returns the default recovery options +func DefaultRecoveryOptions(cfg *config.Config) *RecoveryOptions { + return &RecoveryOptions{ + MaxSequenceNumber: ^uint64(0), // Max uint64 + MaxMemTables: cfg.MaxMemTables, + MemTableSize: cfg.MemTableSize, + } +} + +// RecoverFromWAL rebuilds MemTables from the write-ahead log +// Returns a list of recovered MemTables and the maximum sequence number seen +func RecoverFromWAL(cfg *config.Config, opts *RecoveryOptions) ([]*MemTable, uint64, error) { + if opts == nil { + opts = DefaultRecoveryOptions(cfg) + } + + // Create the first MemTable + memTables := []*MemTable{NewMemTable()} + var maxSeqNum uint64 + + // Function to process each WAL entry + entryHandler := func(entry *wal.Entry) error { + // Skip entries with sequence numbers beyond our max + if entry.SequenceNumber > opts.MaxSequenceNumber { + return nil + } + + // Update the max sequence number + if entry.SequenceNumber > maxSeqNum { + maxSeqNum = entry.SequenceNumber + } + + // Get the current memtable + current := memTables[len(memTables)-1] + + // Check if we should create a new memtable based on size + if current.ApproximateSize() >= opts.MemTableSize { + // Make sure we don't exceed 
the max number of memtables + if len(memTables) >= opts.MaxMemTables { + return fmt.Errorf("maximum number of memtables (%d) exceeded during recovery", opts.MaxMemTables) + } + + // Mark the current memtable as immutable + current.SetImmutable() + + // Create a new memtable + current = NewMemTable() + memTables = append(memTables, current) + } + + // Process the entry + return current.ProcessWALEntry(entry) + } + + // Replay the WAL directory + if err := wal.ReplayWALDir(cfg.WALDir, entryHandler); err != nil { + return nil, 0, fmt.Errorf("failed to replay WAL: %w", err) + } + + // For batch operations, we need to adjust maxSeqNum + finalTable := memTables[len(memTables)-1] + nextSeq := finalTable.GetNextSequenceNumber() + if nextSeq > maxSeqNum+1 { + maxSeqNum = nextSeq - 1 + } + + return memTables, maxSeqNum, nil +} \ No newline at end of file diff --git a/pkg/memtable/recovery_test.go b/pkg/memtable/recovery_test.go new file mode 100644 index 0000000..e66ef2e --- /dev/null +++ b/pkg/memtable/recovery_test.go @@ -0,0 +1,276 @@ +package memtable + +import ( + "os" + "testing" + + "git.canoozie.net/jer/go-storage/pkg/config" + "git.canoozie.net/jer/go-storage/pkg/wal" +) + +func setupTestWAL(t *testing.T) (string, *wal.WAL, func()) { + // Create temporary directory + tmpDir, err := os.MkdirTemp("", "memtable_recovery_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + + // Create config + cfg := config.NewDefaultConfig(tmpDir) + + // Create WAL + w, err := wal.NewWAL(cfg, tmpDir) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("failed to create WAL: %v", err) + } + + // Return cleanup function + cleanup := func() { + w.Close() + os.RemoveAll(tmpDir) + } + + return tmpDir, w, cleanup +} + +func TestRecoverFromWAL(t *testing.T) { + tmpDir, w, cleanup := setupTestWAL(t) + defer cleanup() + + // Add entries to the WAL + entries := []struct { + opType uint8 + key string + value string + }{ + {wal.OpTypePut, "key1", "value1"}, + {wal.OpTypePut, "key2", "value2"}, + {wal.OpTypeDelete, "key1", ""}, + {wal.OpTypePut, "key3", "value3"}, + } + + for _, e := range entries { + var seq uint64 + var err error + + if e.opType == wal.OpTypePut { + seq, err = w.Append(e.opType, []byte(e.key), []byte(e.value)) + } else { + seq, err = w.Append(e.opType, []byte(e.key), nil) + } + + if err != nil { + t.Fatalf("failed to append to WAL: %v", err) + } + t.Logf("Appended entry with seq %d", seq) + } + + // Sync and close WAL + if err := w.Sync(); err != nil { + t.Fatalf("failed to sync WAL: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("failed to close WAL: %v", err) + } + + // Create config for recovery + cfg := config.NewDefaultConfig(tmpDir) + cfg.WALDir = tmpDir + cfg.MemTableSize = 1024 * 1024 // 1MB + + // Recover memtables from WAL + memTables, maxSeq, err := RecoverFromWAL(cfg, nil) + if err != nil { + t.Fatalf("failed to recover from WAL: %v", err) + } + + // Validate recovery results + if len(memTables) == 0 { + t.Fatalf("expected at least one memtable from recovery") + } + + t.Logf("Recovered %d memtables with max sequence %d", len(memTables), maxSeq) + + // The max sequence number should be 4 + if maxSeq != 4 { + t.Errorf("expected max sequence number 4, got %d", maxSeq) + } + + // Validate content of the recovered memtable + mt := memTables[0] + + // key1 should be deleted + value, found := mt.Get([]byte("key1")) + if !found { + t.Errorf("expected key1 to be found (as deleted)") + } + if value != nil { + t.Errorf("expected key1 to have nil value (deleted), 
got %v", value) + } + + // key2 should have "value2" + value, found = mt.Get([]byte("key2")) + if !found { + t.Errorf("expected key2 to be found") + } else if string(value) != "value2" { + t.Errorf("expected key2 to have value 'value2', got '%s'", string(value)) + } + + // key3 should have "value3" + value, found = mt.Get([]byte("key3")) + if !found { + t.Errorf("expected key3 to be found") + } else if string(value) != "value3" { + t.Errorf("expected key3 to have value 'value3', got '%s'", string(value)) + } +} + +func TestRecoveryWithMultipleMemTables(t *testing.T) { + tmpDir, w, cleanup := setupTestWAL(t) + defer cleanup() + + // Create a lot of large entries to force multiple memtables + largeValue := make([]byte, 1000) // 1KB value + for i := 0; i < 10; i++ { + key := []byte{byte(i + 'a')} + if _, err := w.Append(wal.OpTypePut, key, largeValue); err != nil { + t.Fatalf("failed to append to WAL: %v", err) + } + } + + // Sync and close WAL + if err := w.Sync(); err != nil { + t.Fatalf("failed to sync WAL: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("failed to close WAL: %v", err) + } + + // Create config for recovery with small memtable size + cfg := config.NewDefaultConfig(tmpDir) + cfg.WALDir = tmpDir + cfg.MemTableSize = 5 * 1000 // 5KB - should fit about 5 entries + cfg.MaxMemTables = 3 // Allow up to 3 memtables + + // Recover memtables from WAL + memTables, _, err := RecoverFromWAL(cfg, nil) + if err != nil { + t.Fatalf("failed to recover from WAL: %v", err) + } + + // Should have created multiple memtables + if len(memTables) <= 1 { + t.Errorf("expected multiple memtables due to size, got %d", len(memTables)) + } + + t.Logf("Recovered %d memtables", len(memTables)) + + // All memtables except the last one should be immutable + for i, mt := range memTables[:len(memTables)-1] { + if !mt.IsImmutable() { + t.Errorf("expected memtable %d to be immutable", i) + } + } + + // Verify all data was recovered across all memtables + for i := 0; i < 10; i++ { + key := []byte{byte(i + 'a')} + found := false + + // Check each memtable for the key + for _, mt := range memTables { + if _, exists := mt.Get(key); exists { + found = true + break + } + } + + if !found { + t.Errorf("key %c not found in any memtable", i+'a') + } + } +} + +func TestRecoveryWithBatchOperations(t *testing.T) { + tmpDir, w, cleanup := setupTestWAL(t) + defer cleanup() + + // Create a batch of operations + batch := wal.NewBatch() + batch.Put([]byte("batch_key1"), []byte("batch_value1")) + batch.Put([]byte("batch_key2"), []byte("batch_value2")) + batch.Delete([]byte("batch_key3")) + + // Write the batch to the WAL + if err := batch.Write(w); err != nil { + t.Fatalf("failed to write batch to WAL: %v", err) + } + + // Add some individual operations too + if _, err := w.Append(wal.OpTypePut, []byte("key4"), []byte("value4")); err != nil { + t.Fatalf("failed to append to WAL: %v", err) + } + + // Sync and close WAL + if err := w.Sync(); err != nil { + t.Fatalf("failed to sync WAL: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("failed to close WAL: %v", err) + } + + // Create config for recovery + cfg := config.NewDefaultConfig(tmpDir) + cfg.WALDir = tmpDir + + // Recover memtables from WAL + memTables, maxSeq, err := RecoverFromWAL(cfg, nil) + if err != nil { + t.Fatalf("failed to recover from WAL: %v", err) + } + + if len(memTables) == 0 { + t.Fatalf("expected at least one memtable from recovery") + } + + // The max sequence number should account for batch operations + if maxSeq < 3 { // At least 3 
from batch + individual op + t.Errorf("expected max sequence number >= 3, got %d", maxSeq) + } + + // Validate content of the recovered memtable + mt := memTables[0] + + // Check batch keys were recovered + value, found := mt.Get([]byte("batch_key1")) + if !found { + t.Errorf("batch_key1 not found in recovered memtable") + } else if string(value) != "batch_value1" { + t.Errorf("expected batch_key1 to have value 'batch_value1', got '%s'", string(value)) + } + + value, found = mt.Get([]byte("batch_key2")) + if !found { + t.Errorf("batch_key2 not found in recovered memtable") + } else if string(value) != "batch_value2" { + t.Errorf("expected batch_key2 to have value 'batch_value2', got '%s'", string(value)) + } + + // batch_key3 should be marked as deleted + value, found = mt.Get([]byte("batch_key3")) + if !found { + t.Errorf("expected batch_key3 to be found as deleted") + } + if value != nil { + t.Errorf("expected batch_key3 to have nil value (deleted), got %v", value) + } + + // Check individual operation was recovered + value, found = mt.Get([]byte("key4")) + if !found { + t.Errorf("key4 not found in recovered memtable") + } else if string(value) != "value4" { + t.Errorf("expected key4 to have value 'value4', got '%s'", string(value)) + } +} \ No newline at end of file diff --git a/pkg/memtable/skiplist.go b/pkg/memtable/skiplist.go new file mode 100644 index 0000000..5c8e33d --- /dev/null +++ b/pkg/memtable/skiplist.go @@ -0,0 +1,308 @@ +package memtable + +import ( + "bytes" + "math/rand" + "sync" + "sync/atomic" + "time" + "unsafe" +) + +const ( + // MaxHeight is the maximum height of the skip list + MaxHeight = 12 + + // BranchingFactor determines the probability of increasing the height + BranchingFactor = 4 + + // DefaultCacheLineSize aligns nodes to cache lines for better performance + DefaultCacheLineSize = 64 +) + +// ValueType represents the type of a key-value entry +type ValueType uint8 + +const ( + // TypeValue indicates the entry contains a value + TypeValue ValueType = iota + 1 + + // TypeDeletion indicates the entry is a tombstone (deletion marker) + TypeDeletion +) + +// entry represents a key-value pair with additional metadata +type entry struct { + key []byte + value []byte + valueType ValueType + seqNum uint64 +} + +// newEntry creates a new entry +func newEntry(key, value []byte, valueType ValueType, seqNum uint64) *entry { + return &entry{ + key: key, + value: value, + valueType: valueType, + seqNum: seqNum, + } +} + +// size returns the approximate size of the entry in memory +func (e *entry) size() int { + return len(e.key) + len(e.value) + 16 // adding overhead for metadata +} + +// compare compares this entry with another key +// Returns: negative if e.key < key, 0 if equal, positive if e.key > key +func (e *entry) compare(key []byte) int { + return bytes.Compare(e.key, key) +} + +// compareWithEntry compares this entry with another entry +// First by key, then by sequence number (in reverse order to prioritize newer entries) +func (e *entry) compareWithEntry(other *entry) int { + cmp := bytes.Compare(e.key, other.key) + if cmp == 0 { + // If keys are equal, compare sequence numbers in reverse order (newer first) + if e.seqNum > other.seqNum { + return -1 + } else if e.seqNum < other.seqNum { + return 1 + } + return 0 + } + return cmp +} + +// node represents a node in the skip list +type node struct { + entry *entry + height int32 + // next contains pointers to the next nodes at each level + // This is allocated as a single block for cache efficiency + next 
[MaxHeight]unsafe.Pointer
+}
+
+// newNode creates a new node with a random height
+func newNode(e *entry, height int) *node {
+	return &node{
+		entry:  e,
+		height: int32(height),
+	}
+}
+
+// getNext returns the next node at the given level
+func (n *node) getNext(level int) *node {
+	return (*node)(atomic.LoadPointer(&n.next[level]))
+}
+
+// setNext sets the next node at the given level
+func (n *node) setNext(level int, next *node) {
+	atomic.StorePointer(&n.next[level], unsafe.Pointer(next))
+}
+
+// SkipList is a skip list implementation for the MemTable. Readers may
+// traverse it concurrently (links are published with atomic stores), but
+// writers must be serialized externally; MemTable does this with its mutex.
+type SkipList struct {
+	head      *node
+	maxHeight int32
+	rnd       *rand.Rand
+	rndMtx    sync.Mutex
+	size      int64
+}
+
+// NewSkipList creates a new skip list
+func NewSkipList() *SkipList {
+	seed := time.Now().UnixNano()
+	list := &SkipList{
+		head:      newNode(nil, MaxHeight),
+		maxHeight: 1,
+		rnd:       rand.New(rand.NewSource(seed)),
+	}
+	return list
+}
+
+// randomHeight generates a random height for a new node
+func (s *SkipList) randomHeight() int {
+	s.rndMtx.Lock()
+	defer s.rndMtx.Unlock()
+
+	height := 1
+	for height < MaxHeight && s.rnd.Intn(BranchingFactor) == 0 {
+		height++
+	}
+	return height
+}
+
+// getCurrentHeight returns the current maximum height of the skip list
+func (s *SkipList) getCurrentHeight() int {
+	return int(atomic.LoadInt32(&s.maxHeight))
+}
+
+// Insert adds a new entry to the skip list
+func (s *SkipList) Insert(e *entry) {
+	height := s.randomHeight()
+	prev := [MaxHeight]*node{}
+	node := newNode(e, height)
+
+	// Try to increase the height of the list
+	currHeight := s.getCurrentHeight()
+	if height > currHeight {
+		// Attempt to increase the height
+		if atomic.CompareAndSwapInt32(&s.maxHeight, int32(currHeight), int32(height)) {
+			currHeight = height
+		}
+	}
+
+	// Levels above the search height have no predecessor yet; they
+	// start at the head node
+	for level := currHeight; level < height; level++ {
+		prev[level] = s.head
+	}
+
+	// Find where to insert at each level
+	current := s.head
+	for level := currHeight - 1; level >= 0; level-- {
+		// Find the insertion point at this level
+		for next := current.getNext(level); next != nil; next = current.getNext(level) {
+			if next.entry.compareWithEntry(e) >= 0 {
+				break
+			}
+			current = next
+		}
+		prev[level] = current
+	}
+
+	// Insert the node at each level
+	for level := 0; level < height; level++ {
+		node.setNext(level, prev[level].getNext(level))
+		prev[level].setNext(level, node)
+	}
+
+	// Update approximate size
+	atomic.AddInt64(&s.size, int64(e.size()))
+}
+
+// Find looks for an entry with the specified key
+// If multiple entries have the same key, the most recent one is returned
+func (s *SkipList) Find(key []byte) *entry {
+	var result *entry
+	current := s.head
+	height := s.getCurrentHeight()
+
+	// Start from the highest level for efficient search
+	for level := height - 1; level >= 0; level-- {
+		// Scan forward until we find a key greater than or equal to the target
+		for next := current.getNext(level); next != nil; next = current.getNext(level) {
+			cmp := next.entry.compare(key)
+			if cmp > 0 {
+				// Key at next is greater than target, go down a level
+				break
+			} else if cmp == 0 {
+				// Found a match, check if it's newer than our current result
+				if result == nil || next.entry.seqNum > result.seqNum {
+					result = next.entry
+				}
+				// Continue at this level to see if there are more entries with same key
+				current = next
+			} else {
+				// Key at next is less than target, move forward
+				current = next
+			}
+		}
+	}
+
+	// The tower search can land past newer duplicates that exist only at
+	// lower levels, so do one level-0 sweep to ensure we get the newest entry
+	current = s.head
+	for next := current.getNext(0); next != nil; next = next.getNext(0) {
+		cmp := next.entry.compare(key)
+		if cmp 
> 0 { + // Past the key + break + } else if cmp == 0 { + // Found a match, update result if it's newer + if result == nil || next.entry.seqNum > result.seqNum { + result = next.entry + } + } + current = next + } + + return result +} + +// ApproximateSize returns the approximate size of the skip list in bytes +func (s *SkipList) ApproximateSize() int64 { + return atomic.LoadInt64(&s.size) +} + +// Iterator provides sequential access to the skip list entries +type Iterator struct { + list *SkipList + current *node +} + +// NewIterator creates a new Iterator for the skip list +func (s *SkipList) NewIterator() *Iterator { + return &Iterator{ + list: s, + current: s.head, + } +} + +// Valid returns true if the iterator is positioned at a valid entry +func (it *Iterator) Valid() bool { + return it.current != nil && it.current != it.list.head +} + +// Next advances the iterator to the next entry +func (it *Iterator) Next() { + if it.current == nil { + return + } + it.current = it.current.getNext(0) +} + +// SeekToFirst positions the iterator at the first entry +func (it *Iterator) SeekToFirst() { + it.current = it.list.head.getNext(0) +} + +// Seek positions the iterator at the first entry with a key >= target +func (it *Iterator) Seek(key []byte) { + // Start from head + current := it.list.head + height := it.list.getCurrentHeight() + + // Search algorithm similar to Find + for level := height - 1; level >= 0; level-- { + for next := current.getNext(level); next != nil; next = current.getNext(level) { + if next.entry.compare(key) >= 0 { + break + } + current = next + } + } + + // Move to the next node, which should be >= target + it.current = current.getNext(0) +} + +// Key returns the key of the current entry +func (it *Iterator) Key() []byte { + if !it.Valid() { + return nil + } + return it.current.entry.key +} + +// Value returns the value of the current entry +func (it *Iterator) Value() []byte { + if !it.Valid() { + return nil + } + return it.current.entry.value +} + +// Entry returns the current entry +func (it *Iterator) Entry() *entry { + if !it.Valid() { + return nil + } + return it.current.entry +} \ No newline at end of file diff --git a/pkg/memtable/skiplist_test.go b/pkg/memtable/skiplist_test.go new file mode 100644 index 0000000..09992fe --- /dev/null +++ b/pkg/memtable/skiplist_test.go @@ -0,0 +1,232 @@ +package memtable + +import ( + "bytes" + "testing" +) + +func TestSkipListBasicOperations(t *testing.T) { + sl := NewSkipList() + + // Test insertion + e1 := newEntry([]byte("key1"), []byte("value1"), TypeValue, 1) + e2 := newEntry([]byte("key2"), []byte("value2"), TypeValue, 2) + e3 := newEntry([]byte("key3"), []byte("value3"), TypeValue, 3) + + sl.Insert(e1) + sl.Insert(e2) + sl.Insert(e3) + + // Test lookup + found := sl.Find([]byte("key2")) + if found == nil { + t.Fatalf("expected to find key2, but got nil") + } + if string(found.value) != "value2" { + t.Errorf("expected value to be 'value2', got '%s'", string(found.value)) + } + + // Test lookup of non-existent key + notFound := sl.Find([]byte("key4")) + if notFound != nil { + t.Errorf("expected nil for non-existent key, got %v", notFound) + } +} + +func TestSkipListSequenceNumbers(t *testing.T) { + sl := NewSkipList() + + // Insert same key with different sequence numbers + e1 := newEntry([]byte("key"), []byte("value1"), TypeValue, 1) + e2 := newEntry([]byte("key"), []byte("value2"), TypeValue, 2) + e3 := newEntry([]byte("key"), []byte("value3"), TypeValue, 3) + + // Insert in reverse order to test ordering + sl.Insert(e3) 
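+	// Entries with equal keys are kept adjacent and ordered newest-first by
+	// compareWithEntry, so insertion order should not affect what Find returns.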
+ sl.Insert(e2) + sl.Insert(e1) + + // Find should return the entry with the highest sequence number + found := sl.Find([]byte("key")) + if found == nil { + t.Fatalf("expected to find key, but got nil") + } + if string(found.value) != "value3" { + t.Errorf("expected value to be 'value3' (highest seq num), got '%s'", string(found.value)) + } + if found.seqNum != 3 { + t.Errorf("expected sequence number to be 3, got %d", found.seqNum) + } +} + +func TestSkipListIterator(t *testing.T) { + sl := NewSkipList() + + // Insert entries + entries := []struct { + key string + value string + seq uint64 + }{ + {"apple", "red", 1}, + {"banana", "yellow", 2}, + {"cherry", "red", 3}, + {"date", "brown", 4}, + {"elderberry", "purple", 5}, + } + + for _, e := range entries { + sl.Insert(newEntry([]byte(e.key), []byte(e.value), TypeValue, e.seq)) + } + + // Test iteration + it := sl.NewIterator() + it.SeekToFirst() + + count := 0 + for it.Valid() { + if count >= len(entries) { + t.Fatalf("iterator returned more entries than expected") + } + + expectedKey := entries[count].key + expectedValue := entries[count].value + + if string(it.Key()) != expectedKey { + t.Errorf("at position %d, expected key '%s', got '%s'", count, expectedKey, string(it.Key())) + } + if string(it.Value()) != expectedValue { + t.Errorf("at position %d, expected value '%s', got '%s'", count, expectedValue, string(it.Value())) + } + + it.Next() + count++ + } + + if count != len(entries) { + t.Errorf("expected to iterate through %d entries, but got %d", len(entries), count) + } +} + +func TestSkipListSeek(t *testing.T) { + sl := NewSkipList() + + // Insert entries + entries := []struct { + key string + value string + seq uint64 + }{ + {"apple", "red", 1}, + {"banana", "yellow", 2}, + {"cherry", "red", 3}, + {"date", "brown", 4}, + {"elderberry", "purple", 5}, + } + + for _, e := range entries { + sl.Insert(newEntry([]byte(e.key), []byte(e.value), TypeValue, e.seq)) + } + + testCases := []struct { + seek string + expected string + valid bool + }{ + // Before first entry + {"a", "apple", true}, + // Exact match + {"cherry", "cherry", true}, + // Between entries + {"blueberry", "cherry", true}, + // After last entry + {"zebra", "", false}, + } + + for _, tc := range testCases { + t.Run(tc.seek, func(t *testing.T) { + it := sl.NewIterator() + it.Seek([]byte(tc.seek)) + + if it.Valid() != tc.valid { + t.Errorf("expected Valid() to be %v, got %v", tc.valid, it.Valid()) + } + + if tc.valid { + if string(it.Key()) != tc.expected { + t.Errorf("expected key '%s', got '%s'", tc.expected, string(it.Key())) + } + } + }) + } +} + +func TestEntryComparison(t *testing.T) { + testCases := []struct { + e1, e2 *entry + expected int + }{ + // Different keys + { + newEntry([]byte("a"), []byte("val"), TypeValue, 1), + newEntry([]byte("b"), []byte("val"), TypeValue, 1), + -1, + }, + { + newEntry([]byte("b"), []byte("val"), TypeValue, 1), + newEntry([]byte("a"), []byte("val"), TypeValue, 1), + 1, + }, + // Same key, different sequence numbers (higher seq should be "less") + { + newEntry([]byte("same"), []byte("val1"), TypeValue, 2), + newEntry([]byte("same"), []byte("val2"), TypeValue, 1), + -1, + }, + { + newEntry([]byte("same"), []byte("val1"), TypeValue, 1), + newEntry([]byte("same"), []byte("val2"), TypeValue, 2), + 1, + }, + // Same key, same sequence number + { + newEntry([]byte("same"), []byte("val"), TypeValue, 1), + newEntry([]byte("same"), []byte("val"), TypeValue, 1), + 0, + }, + } + + for i, tc := range testCases { + result := 
tc.e1.compareWithEntry(tc.e2) + expected := tc.expected + // We just care about the sign + if (result < 0 && expected >= 0) || (result > 0 && expected <= 0) || (result == 0 && expected != 0) { + t.Errorf("case %d: expected comparison result %d, got %d", i, expected, result) + } + } +} + +func TestSkipListApproximateSize(t *testing.T) { + sl := NewSkipList() + + // Initial size should be 0 + if size := sl.ApproximateSize(); size != 0 { + t.Errorf("expected initial size to be 0, got %d", size) + } + + // Add some entries + e1 := newEntry([]byte("key1"), []byte("value1"), TypeValue, 1) + e2 := newEntry([]byte("key2"), bytes.Repeat([]byte("v"), 100), TypeValue, 2) + + sl.Insert(e1) + expectedSize := int64(e1.size()) + if size := sl.ApproximateSize(); size != expectedSize { + t.Errorf("expected size to be %d after first insert, got %d", expectedSize, size) + } + + sl.Insert(e2) + expectedSize += int64(e2.size()) + if size := sl.ApproximateSize(); size != expectedSize { + t.Errorf("expected size to be %d after second insert, got %d", expectedSize, size) + } +} \ No newline at end of file
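---

Reviewer note: a minimal sketch of how the pieces introduced here are
expected to compose. The engine wiring below (openEngine, write, and the
flush hand-off) is illustrative only and not part of this patch; it uses
only APIs added or exercised in this diff (RecoverFromWAL, NewMemTablePool,
Put, IsFlushNeeded, SwitchToNewMemTable, wal.NewWAL/Append as used in the
tests), and "flush to SSTable" stands in for the Phase C flush path that
does not exist yet.

    package engine

    import (
        "log"

        "git.canoozie.net/jer/go-storage/pkg/config"
        "git.canoozie.net/jer/go-storage/pkg/memtable"
        "git.canoozie.net/jer/go-storage/pkg/wal"
    )

    // openEngine replays the WAL into memtables before serving traffic.
    func openEngine(cfg *config.Config) (*memtable.MemTablePool, uint64, error) {
        tables, maxSeq, err := memtable.RecoverFromWAL(cfg, nil)
        if err != nil {
            return nil, 0, err
        }
        log.Printf("recovered %d memtable(s), max seq %d", len(tables), maxSeq)

        // Recovered tables would be handed to the (future) flush path;
        // new writes go through a fresh pool.
        pool := memtable.NewMemTablePool(cfg)
        return pool, maxSeq, nil
    }

    // write appends to the WAL first, then applies to the memtable, so an
    // acknowledged write is never lost on crash.
    func write(pool *memtable.MemTablePool, w *wal.WAL, key, value []byte) error {
        seq, err := w.Append(wal.OpTypePut, key, value)
        if err != nil {
            return err
        }
        pool.Put(key, value, seq)

        // Put checks the size/age thresholds; rotate when they fire. The
        // returned immutable table would be queued for an SSTable flush.
        if pool.IsFlushNeeded() {
            pool.SwitchToNewMemTable()
        }
        return nil
    }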